vectorized activation vm
This is a side-project I'm working on during live-coding sessions on Twitch. The primary stream slot is Monday evenings (8:30 CET or so) on https://www.twitch.tv/kalicoding . The unedited videos are also available at https://www.youtube.com/playlist?list=PL2L_7bZXXxQuNpcxjg8gipegP32pmwVZF .
The landing page for the Twitch streams, with some explanations, is here: https://github.com/kali/coding .
- [x] Activation function inventory and Rust POC (parts 1 and 2)
- [x] Early benches (part 3)
- [x] Bringing up assembly integration in tract (parts 4, 5, 6)
- [x] Rework the protocol so constants are in the flow of instructions (Rust, assembly) (part 7)
- [x] Jump table generation (part 8)
- [x] Support for "bloc-1" activations (ReLU to HardSwish) (part 8, part 9)
- [x] Tests should be one-liners (Rust) (part 8)
- [x] bench (part 10)
- [ ] integrate in tract-core
- [ ] ARMv8.2 f16 support
- [ ] other activations: Sigmoid (part 10), Tanh, Exp2f
- [ ] implement on ARMv7
- [ ] implement on Intel (primary target AVX2+FMA)
- [ ] consider compiling instead of interpreting
We are interested in element-wise operations:
- the result is independent of the tensor geometry (only scalar operands are accepted, no broadcasting of tensors)
- output = f(x) for x in input
"simple ops + if/then/else" (scalar reference sketches follow the list):
- Relu(x) = max(0, x)
- Affine(x) = alpha * x + beta
- LeakyRelu(x) = if x >= 0 { x } else { alpha * x }
- ThresholdRelu(x) = if x >= alpha { x } else { 0 }
- HardSigmoid(x) = max(0, min(1, alpha * x + beta))
- Softsign(x) = x / (1 + abs(x))
- HardSwish(x) = x * max(0, min(1, alpha * x + beta)) (alpha = 1/6, beta = 1/2)
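For reference, here is a scalar sketch of these functions in plain Rust (the function names and the choice of f32 are mine; this is just something to check the VM programs against later):

```rust
// Scalar reference implementations of the "simple" activations above.
// Plain f32 here; the VM itself works on big vector registers.

pub fn relu(x: f32) -> f32 {
    x.max(0.0)
}

pub fn affine(x: f32, alpha: f32, beta: f32) -> f32 {
    alpha * x + beta
}

pub fn leaky_relu(x: f32, alpha: f32) -> f32 {
    if x >= 0.0 { x } else { alpha * x }
}

pub fn threshold_relu(x: f32, alpha: f32) -> f32 {
    if x >= alpha { x } else { 0.0 }
}

pub fn hard_sigmoid(x: f32, alpha: f32, beta: f32) -> f32 {
    (alpha * x + beta).clamp(0.0, 1.0)
}

pub fn softsign(x: f32) -> f32 {
    x / (1.0 + x.abs())
}

pub fn hard_swish(x: f32) -> f32 {
    // alpha = 1/6, beta = 1/2
    x * (x / 6.0 + 0.5).clamp(0.0, 1.0)
}
```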
Proposed virtual machine def
- 4 (?) big registers (each mapped to one or several hardware vector registers)
- on ARMv8, one big register could be 4 (or 5?) NEON registers (4x4 f32 values or 4x8 f16 values)
- if 16 NEON registers are used, that leaves 16 of the 32 for housekeeping (pre-fetch caching + operators)
- focusing on the "simple" activation functions
- BigRegs are a,b,c,d
- min, max, abs, +, -, *, /, ite, ifpos
- constants
- calling convention: framework preloads x into a, expects y in a
Moves and constant loads can target any register; unary, binary and ternary ops operate on a, b, c only (sketched below).
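A sketch of what this could look like in Rust, with a scalar register file standing in for the big vector registers. Everything here (type names, variants, the `run` function) is my own placeholder, not the project's actual code; `ite` is left out since its exact selection test is not pinned down above.

```rust
/// The four big registers a, b, c, d. Each would map to one or several
/// hardware vector registers; a plain f32 stands in for them here.
#[derive(Clone, Copy, Debug)]
enum Reg {
    A,
    B,
    C,
    D,
}

/// One VM instruction. Moves and constant loads can target any register;
/// computation ops use fixed registers: unary ops rewrite a, binary ops
/// compute a <- a # b, ternary ops compute a <- a # b # c.
#[derive(Clone, Copy, Debug)]
enum Op {
    Move(Reg, Reg), // destination, source (the concrete encoding only needs dst in a, b, c)
    Load(Reg, f32), // destination, constant
    Abs,
    Recip,
    Add,
    Sub,
    Mul,
    Div,
    Min,
    Max,
    /// a <- if a >= 0 { b } else { c }
    IfPos,
}

/// Scalar reference interpreter. Calling convention: the framework
/// preloads x into a and reads the result back from a.
fn run(program: &[Op], x: f32) -> f32 {
    let mut r = [x, 0.0, 0.0, 0.0]; // a, b, c, d
    for op in program {
        match *op {
            Op::Move(dst, src) => r[dst as usize] = r[src as usize],
            Op::Load(dst, k) => r[dst as usize] = k,
            Op::Abs => r[0] = r[0].abs(),
            Op::Recip => r[0] = 1.0 / r[0],
            Op::Add => r[0] += r[1],
            Op::Sub => r[0] -= r[1],
            Op::Mul => r[0] *= r[1],
            Op::Div => r[0] /= r[1],
            Op::Min => r[0] = r[0].min(r[1]),
            Op::Max => r[0] = r[0].max(r[1]),
            Op::IfPos => r[0] = if r[0] >= 0.0 { r[1] } else { r[2] },
        }
    }
    r[0]
}
```

`recip` is written as a plain division here; on NEON it would more likely be a reciprocal estimate plus refinement steps, or simply a vector divide.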
- we want to limit the number of op/operand combinations (the programs below are sanity-checked in a sketch after this list)
- moves: 12 move ops (from any of the 4 source registers to one of the 3 destinations a, b, c)
- all other ops have fixed registers: unary ops a <- a, binary a <- a#b, ternary a <- a#b#c
- Relu: load(b, 0) | max
- Affine: load(b, alpha) | mul | load(b, beta) | add
- LeakyRelu: b <- a | load(a, alpha) | mul | c <- a | ifpos
- ThresholdRelu: c <- a | load(b, alpha) | sub | b <- c | load(c, 0) | ifpos
- Softsign: c <- a | abs | load(b, 1) | add | recip | b <- c | mul
- HardSwish: c <- a | load(b, alpha) | mul | load(b, beta) | add | load(b, 1) | min | load(b, 0) | max | b <- c | mul
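Reusing the hypothetical `Op`, `Reg` and `run` from the sketch above (so this snippet is not standalone), the Relu and Softsign programs can be written down and checked directly against their scalar definitions:

```rust
fn main() {
    // Relu: load(b, 0) | max
    let relu = vec![Op::Load(Reg::B, 0.0), Op::Max];
    // Softsign: c <- a | abs | load(b, 1) | add | recip | b <- c | mul
    let softsign = vec![
        Op::Move(Reg::C, Reg::A), // c <- a       (save x)
        Op::Abs,                  // a <- |x|
        Op::Load(Reg::B, 1.0),    // load(b, 1)
        Op::Add,                  // a <- 1 + |x|
        Op::Recip,                // a <- 1 / (1 + |x|)
        Op::Move(Reg::B, Reg::C), // b <- c       (restore x)
        Op::Mul,                  // a <- x / (1 + |x|)
    ];
    for &x in &[-3.0f32, -0.5, 0.0, 0.5, 3.0] {
        assert_eq!(run(&relu, x), x.max(0.0));
        assert!((run(&softsign, x) - x / (1.0 + x.abs())).abs() <= 1e-6);
    }
    println!("Relu and Softsign programs match their scalar definitions");
}
```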
Better spec:
- what if we add ops of the form a <- a#K (K a constant)?
- Relu: max(0)
- Affine: mul(alpha) | add(beta)
- LeakyRelu: b <- a | mul(alpha) | c <- a | ifpos
- ThresholdRelu: b <- a | sub(alpha) | load(c, 0) | ifpos
- Softsign: b <- a | abs | add(1) | recip | mul
- HardSwish: b <- a | mul(alpha) | add(beta) | min(1) | max(0) | mul (traced in the sketch below)
Much better :)
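As a sanity check of the constant-carrying encoding, here is a scalar trace of the HardSwish program above, one Rust statement per instruction (a sketch; the function name and the f32 model are mine):

```rust
// Scalar trace of the "better spec" HardSwish program:
// b <- a | mul(alpha) | add(beta) | min(1) | max(0) | mul
fn hard_swish_vm_trace(x: f32) -> f32 {
    let (alpha, beta) = (1.0 / 6.0, 0.5);
    let mut a = x;    // calling convention: x preloaded into a
    let b = a;        // b <- a      (save x for the final product)
    a *= alpha;       // mul(alpha)
    a += beta;        // add(beta)
    a = a.min(1.0);   // min(1)
    a = a.max(0.0);   // max(0)
    a *= b;           // mul         (a <- a * b)
    a                 // result read back from a
}

fn main() {
    for &x in &[-4.0f32, -1.0, 0.0, 1.0, 4.0] {
        let reference = x * (x / 6.0 + 0.5).clamp(0.0, 1.0);
        assert!((hard_swish_vm_trace(x) - reference).abs() <= 1e-6);
    }
    println!("HardSwish program matches its definition");
}
```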
What about other activation functions?
Some are implemented in tract with rational approximations:
- P(x)/Q(x), where P and Q are polynomials in x (the general shape is sketched after this list)
- Some can be expressed from e^x:
- Tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x))
- Sigmoid(x) = 1 / (1 + e^(-x))
- But it is faster to use a rational approximation: Tanh and Sigmoid have handcrafted vectorized assembly implementations.
- Q: Can we extend our VM to support rational approximations?
- Erf can only be computed from a rational approximation (as it is less used than Tanh and Sigmoid, it just has a Rust implementation).
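To make the P(x)/Q(x) shape concrete, here is a sketch using Horner evaluation, with a small textbook Padé approximant of tanh standing in for the real fitted coefficients (tract's actual implementations use higher-degree fitted polynomials and range handling, not reproduced here):

```rust
/// Evaluate a polynomial with Horner's rule, coefficients given from
/// highest to lowest degree.
fn horner(x: f32, coeffs: &[f32]) -> f32 {
    coeffs.iter().fold(0.0, |acc, &c| acc * x + c)
}

/// Rational approximation P(x) / Q(x). The coefficients below are the
/// small [3/2] Pade approximant tanh(x) ~ x * (x^2 + 15) / (6x^2 + 15),
/// only accurate near 0; a production version needs fitted coefficients
/// and clamping of large |x|.
fn tanh_rational(x: f32) -> f32 {
    let p = [1.0, 0.0, 15.0, 0.0]; // x^3 + 15x
    let q = [6.0, 0.0, 15.0];      // 6x^2 + 15
    horner(x, &p) / horner(x, &q)
}

fn main() {
    for &x in &[-1.0f32, -0.5, 0.0, 0.5, 1.0] {
        println!("x = {x:+.2}   approx = {:+.6}   tanh = {:+.6}", tanh_rational(x), x.tanh());
    }
}
```

In VM terms, each Horner step is a multiply by x plus a constant add, so with x kept in b it would fit the existing binary-op and constant patterns.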
More activation functions are computed from exp (and also log, tanh):
- currently implemented as a separate expansion into several tract operators
- should we look for approximations? Rational or polynomial?
- can we / should we add e^x as a primitive? (see the Elu sketch after this list)
- can we do as well as / better than the default implementation?
- ScaledTanh(x) = alpha * tanh(beta * x) (should be easily derived from tanh)
- Softplus(x) = log(1 + e^x)
- Celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
- Elu(x) = if x < 0 { alpha * (exp(x) - 1) } else { x }
- Selu(x) = if x < 0 { gamma * (alpha * e^x - alpha) } else { gamma * x }
- Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
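If e^x were added as a unary primitive, some of these would reduce to short programs under the "better spec" above. Here is a scalar trace of one possible Elu program; the instruction sequence is my guess, not a committed design:

```rust
// Scalar trace of a hypothetical Elu program, assuming an exp primitive
// that, like the other unary ops, rewrites a in place:
// b <- a | exp | sub(1) | mul(alpha) | c <- a | a <- b | ifpos
fn elu_vm_trace(x: f32, alpha: f32) -> f32 {
    let mut a = x;                    // calling convention: x preloaded into a
    let b = a;                        // b <- a      (save x)
    a = a.exp();                      // exp         (a = e^x)
    a -= 1.0;                         // sub(1)
    a *= alpha;                       // mul(alpha)  (a = alpha * (e^x - 1))
    let c = a;                        // c <- a      (save the negative branch)
    a = b;                            // a <- b      (put x back for the sign test)
    a = if a >= 0.0 { b } else { c }; // ifpos
    a                                 // result read back from a
}

fn main() {
    for &x in &[-2.0f32, -0.5, 0.0, 0.5, 2.0] {
        let reference = if x < 0.0 { 0.7 * (x.exp() - 1.0) } else { x };
        assert!((elu_vm_trace(x, 0.7) - reference).abs() <= 1e-6);
    }
    println!("Elu program matches its definition");
}
```

Softplus would similarly become exp | add(1) | ln if a log primitive were also added, which is exactly the trade-off raised in the list above.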