
vectorized activation vm

Open • kali opened this issue 2 years ago • 1 comment

This is a side-project I'm working on during live-coding sessions on Twitch. The primary time slot for streams is Monday evenings (around 8:30 CET) on https://www.twitch.tv/kalicoding . The unedited videos are also available at https://www.youtube.com/playlist?list=PL2L_7bZXXxQuNpcxjg8gipegP32pmwVZF .

The landing page for the Twitch streams, with some explanations, is here: https://github.com/kali/coding .

  • [x] Activation function inventory and Rust POC (parts 1 and 2)
  • [x] Early benchmarks (part 3)
  • [x] Bringing up assembly integration in tract (parts 4, 5, 6)
  • [x] Rework the protocol so constants are in the flow of instructions (Rust, assembly) (part 7)
  • [x] Jump table generation (part 8)
  • [x] Support for "bloc-1" activation (ReLu to HardSwish) (part 8, part 9)
  • [x] Tests should be one-liners (Rust) (part 8)
  • [x] bench (part 10)
  • [ ] integrate in tract-core
  • [ ] ARMv8.2 f16 support
  • [ ] other activations: Sigmoid (part 10), Tanh, Exp2f
  • [ ] implement on ARMv7
  • [ ] implement on Intel (primary target AVX2+FMA)
  • [ ] consider compiling instead of interpreting

kali • Mar 21 '23 22:03

  • We are interested in element-wise operations:

    • result is independent of tensor geometry (parameters are scalars only, no broadcasting of tensors)
    • output = f(x) for x in input
  • "simple ops + if/then/else"

    • Relu(x) = max(0, x)
    • Affine = alpha * x + beta
    • LeakyRelu = if x >= 0 { x } else { alpha * x }
    • ThresholdRelu = if x >= alpha { x } else { 0 }
    • HardSigmoid = min(max(alpha * x + beta, 0), 1)
    • Softsign = x / (1 + abs(x))
    • HardSwish = x * max(0, min(1, alpha * x + beta)) (alpha = 1/6 and beta = 1/2)
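
For reference, here is a minimal scalar Rust sketch of the definitions above (f32 only, names and parameters follow the list; this is the plain reference, not the vectorized code):

```rust
// Scalar f32 reference for the "simple" activations above, for checking
// the VM output later (not the vectorized implementation).
fn relu(x: f32) -> f32 { x.max(0.0) }
fn affine(x: f32, alpha: f32, beta: f32) -> f32 { alpha * x + beta }
fn leaky_relu(x: f32, alpha: f32) -> f32 { if x >= 0.0 { x } else { alpha * x } }
fn threshold_relu(x: f32, alpha: f32) -> f32 { if x >= alpha { x } else { 0.0 } }
fn hard_sigmoid(x: f32, alpha: f32, beta: f32) -> f32 { (alpha * x + beta).clamp(0.0, 1.0) }
fn softsign(x: f32) -> f32 { x / (1.0 + x.abs()) }
fn hard_swish(x: f32) -> f32 {
    // alpha = 1/6, beta = 1/2
    x * (x / 6.0 + 0.5).clamp(0.0, 1.0)
}
```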

Proposed virtual machine definition

  • 4 (?) big registers (mapped to one or several hardware vector registers)
    • on armv8, one big register could be 4 (or 5?) NEON registers (4×4 f32 values or 4×8 f16 values)
    • if 16 NEON registers are used, it leaves 16 for housekeeping (pre-fetch caching + operators)
  • focusing on the "simple" activation functions
    • BigRegs are a,b,c,d
    • min, max, abs, +, -, *, /, ite, ifpos
    • constants
  • calling convention: framework preloads x into a, expects y in a
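
A rough sketch of the VM state under these assumptions (one big register = 4 NEON registers of f32; names are illustrative, not tract's actual types):

```rust
/// One "big" register: 4 NEON q-registers' worth of f32 lanes (4 × 4 = 16).
/// With f16 the same four registers would hold 4 × 8 = 32 lanes.
/// (Illustrative sketch, not tract's actual types.)
#[derive(Clone, Copy, Default)]
struct BigReg([f32; 16]);

/// VM state: the four big registers a, b, c, d.
/// Calling convention: the framework preloads x into `a` and reads y back from `a`.
#[derive(Default)]
struct VmState {
    a: BigReg,
    b: BigReg,
    c: BigReg,
    d: BigReg,
}
```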

Moves and constant loads can target any register; unary, binary, and ternary ops work on a, b, c only

  • we want to limit the number of op × operand combinations

  • moves: 12 move ops (4 source registers × 3 destinations)

  • all other ops have fixed registers: unary ops a <- a, binary a <- a#b, ternary a <- a#b#c

  • Relu: load(b, 0) | max

  • Affine: load(b, alpha) | mul | load(b, beta) | add

  • LeakyRelu: b <- a | load(a, alpha) | mul | c <- a | ifpos

  • ThresholdRelu: c <- a | load(b, alpha) | sub | b <- c | load(c, 0) | ifpos

  • Softsign: c <- a | abs | load(b, 1) | add | recip | b <- c | mul

  • HardSwish: c <- a | load(b, alpha) | mul | load(b, beta) | add | load(b, 1) | min | load(b, 0) | max | b <- c | mul
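
A minimal sketch of this instruction set as a Rust interpreter, using a single scalar f32 per register for readability (the real VM works on the big registers above and ends up in assembly). It assumes ifpos means a <- if a >= 0 { b } else { c }, which matches the LeakyRelu and ThresholdRelu sequences; ite, / and the f16 path are left out.

```rust
#[derive(Clone, Copy)]
enum Reg { A, B, C, D }

#[derive(Clone, Copy)]
enum Op {
    Move(Reg, Reg),          // dst <- src
    Load(Reg, f32),          // dst <- constant (broadcast)
    Min, Max, Add, Sub, Mul, // binary: a <- a # b
    Abs, Recip,              // unary:  a <- f(a)
    IfPos,                   // assumed: a <- if a >= 0 { b } else { c }
}

/// Scalar model of the interpreter: one f32 stands in for a whole big register.
fn run(prog: &[Op], x: f32) -> f32 {
    let mut r = [x, 0.0, 0.0, 0.0]; // a, b, c, d; x is preloaded into a
    for op in prog {
        match *op {
            Op::Move(dst, src) => r[dst as usize] = r[src as usize],
            Op::Load(dst, k) => r[dst as usize] = k,
            Op::Min => r[0] = r[0].min(r[1]),
            Op::Max => r[0] = r[0].max(r[1]),
            Op::Add => r[0] += r[1],
            Op::Sub => r[0] -= r[1],
            Op::Mul => r[0] *= r[1],
            Op::Abs => r[0] = r[0].abs(),
            Op::Recip => r[0] = 1.0 / r[0],
            Op::IfPos => r[0] = if r[0] >= 0.0 { r[1] } else { r[2] },
        }
    }
    r[0] // result is expected in a
}

// Relu as listed above: load(b, 0) | max
fn relu_prog() -> Vec<Op> {
    vec![Op::Load(Reg::B, 0.0), Op::Max]
}
```

e.g. run(&relu_prog(), -3.0) returns 0.0, and the other sequences above can be checked against the scalar references the same way.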

Better spec:

  • what if we add: a <- a#K (K constant)

  • Relu: max(0)

  • Affine: mul(alpha) | add(beta)

  • LeakyRelu: b <- a | mul(alpha) | c <- a | ifpos

  • ThresholdRelu: b <- a | sub(alpha) | load(c, 0) | ifpos

  • Softsign: b <- a | abs | add(1) | recip | mul

  • HardSwish: b <- a | mul(alpha) | add(beta) | min(1) | max(0) | mul

Much better :)
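
Concretely, the a <- a#K form just adds immediate variants of the binary ops. Reusing Reg and the scalar model from the sketch above, a hypothetical encoding (and HardSwish, which keeps x alive in b for the final multiply) could look like:

```rust
#[derive(Clone, Copy)]
enum Op2 {
    Move(Reg, Reg),
    Load(Reg, f32),
    // register forms: a <- a # b
    Min, Max, Add, Sub, Mul,
    // immediate forms: a <- a # K
    MinK(f32), MaxK(f32), AddK(f32), SubK(f32), MulK(f32),
    Abs, Recip,
    IfPos,
}

// HardSwish: b <- a | mul(alpha) | add(beta) | min(1) | max(0) | mul
fn hard_swish_prog(alpha: f32, beta: f32) -> Vec<Op2> {
    vec![
        Op2::Move(Reg::B, Reg::A), // keep x in b
        Op2::MulK(alpha),          // a = alpha * x
        Op2::AddK(beta),           // a = alpha * x + beta
        Op2::MinK(1.0),            // clamp above at 1
        Op2::MaxK(0.0),            // clamp below at 0
        Op2::Mul,                  // a = x * clamp(alpha * x + beta, 0, 1)
    ]
}
```

The register forms stay around for the sequences that combine two data-dependent values (LeakyRelu, Softsign).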

What about other activation functions?

  • Some are implemented in tract with rational approximations

    • P(x)/Q(x) where P and Q are polynomials in x
    • Some can be expressed from e^x:
      • Tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x))
      • Sigmoid(x) = 1 / (1 + e^(-x))
    • But it is faster to use rational approx. Tanh and Sigmoid have handcrafted vectorized assembly impls.
    • Q: Can we extend our VM to support rational approximations? (see the sketch after this list)
    • Erf can only be computed from a rational approximation (as it is less used than tanh and sigmoid, it just has a Rust implementation).
  • More activation functions computed from exp (and also log, tanh)

    • implemented as separate expansions into several tract operators
    • should we look for approximations? rational or polynomial?
    • can we / should we add e^x as a primitive ?
      • can we do as well / better than default impl ?
    • ScaledTanh = alpha * tanh(beta * x) (should be easily derived from tanh)
    • Softplus = log(1 + e^x)
    • Celu = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    • Elu = if x < 0 { alpha * (exp(x) - 1) } else { x }
    • Selu = if x < 0 { gamma * (alpha * e^x - alpha) } else { gamma * x }
    • Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
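
On the two open questions above (rational approximations, and e^x as a primitive), a hedged sketch. A P(x)/Q(x) evaluation mostly needs a fused multiply-accumulate with an immediate coefficient on top of the existing recip and mul, so a single hypothetical fma(K) op (a <- a * b + K) might be enough; the scalar model below shows the Horner scheme it would implement (coefficients are placeholders, not tract's actual tanh/sigmoid polynomials). The Elu sequence at the end shows what a hypothetical exp primitive (a <- e^a) would buy.

```rust
// Hypothetical VM extension: fma(K) meaning a <- a * b + K, with x kept in b.
// P(x)/Q(x) by Horner's rule would then encode as:
//   b <- a                                           (keep x)
//   load(a, p_n) | fma(p_{n-1}) | ... | fma(p_0)     -> a = P(x)
//   c <- a                                           (park P(x))
//   load(a, q_m) | fma(q_{m-1}) | ... | fma(q_0)     -> a = Q(x)
//   recip | b <- c | mul                             -> a = P(x) / Q(x)
//
// Scalar model of that computation (placeholder coefficients, highest degree first):
fn rational(x: f32, p: &[f32], q: &[f32]) -> f32 {
    let horner = |coefs: &[f32]| coefs.iter().copied().fold(0.0, |acc, k| acc * x + k);
    horner(p) / horner(q)
}

// With a hypothetical exp primitive (unary a <- e^a), Elu would encode as:
//   b <- a | exp | sub(1) | mul(alpha) | c <- a | a <- b | ifpos
// i.e. if x >= 0 { x } else { alpha * (exp(x) - 1) } -- scalar reference:
fn elu(x: f32, alpha: f32) -> f32 {
    if x >= 0.0 { x } else { alpha * (x.exp() - 1.0) }
}
```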

kali • Mar 22 '23 06:03