triton-transformer
Official layer norm added
Hi @lucidrains, a layer norm was just added to the Triton examples: https://github.com/openai/triton/commit/d4baad426db72b83c5222e1c83c929c1860cae54. I tested it; it's twice as fast as Torch and often faster than Apex.
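A minimal timing sketch of that kind of comparison (not from the thread): `triton_layer_norm` stands in for whatever callable the linked example exposes (the tutorial wraps the kernel in a `torch.autograd.Function`), and the `bench` helper is hypothetical.

```python
import torch
import torch.nn.functional as F

def bench(fn, iters=100):
    # time forward + backward with CUDA events
    for _ in range(10):                     # warmup
        fn().sum().backward()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn().sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per forward + backward

x = torch.randn(4 * 2048, 1536, device='cuda', dtype=torch.float16, requires_grad=True)
w = torch.ones(1536, device='cuda', dtype=torch.float16, requires_grad=True)
b = torch.zeros(1536, device='cuda', dtype=torch.float16, requires_grad=True)

print('torch :', bench(lambda: F.layer_norm(x, (1536,), w, b)), 'ms')
# and the Triton example kernel, e.g. (name/signature depend on the commit):
# print('triton:', bench(lambda: triton_layer_norm(x, (1536,), w, b, 1e-5)), 'ms')
```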
I'm looking forward to your implementation of attention. So far the Torch implementation is the fastest on my data, at 12.3 / 14.5 (forward / backward) vs 17.3 / 23.0 for the other Triton implementation in DeepSpeed.
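A rough sketch of how such per-layer forward/backward timings could be taken, using `torch.nn.MultiheadAttention` (which dispatches to `functional.multi_head_attention_forward`); the head count is an assumption, not from the thread.

```python
import torch

seq, batch, dim, heads = 2048, 4, 1536, 12   # heads assumed
mha = torch.nn.MultiheadAttention(dim, heads).half().cuda()
x = torch.randn(seq, batch, dim, device='cuda', dtype=torch.float16, requires_grad=True)

fwd_start = torch.cuda.Event(enable_timing=True)
fwd_end = torch.cuda.Event(enable_timing=True)
bwd_end = torch.cuda.Event(enable_timing=True)

torch.cuda.synchronize()
fwd_start.record()
out, _ = mha(x, x, x, need_weights=False)
fwd_end.record()
out.sum().backward()
bwd_end.record()
torch.cuda.synchronize()

print('forward :', fwd_start.elapsed_time(fwd_end), 'ms')
print('backward:', fwd_end.elapsed_time(bwd_end), 'ms')
```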
@olegklimov Hi Oleg, you should try the open-sourced softmax and blocksparse matmul by the author of Triton: https://github.com/openai/triton/tree/master/python/triton/ops/blocksparse. I will probably only focus on a simplified variant aimed at making autoregressive attention fast, and not much else.
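A sketch, under assumptions, of wiring dense attention out of those `triton.ops.blocksparse` primitives: an all-ones layout keeps every block, i.e. dense attention computed through the block-sparse kernels. Block size, head count and the `scale=` keyword should be checked against the installed Triton version.

```python
import torch
from triton.ops.blocksparse import matmul, softmax

batch, heads, seq, dim_head, block = 4, 12, 2048, 128, 64
n_blocks = seq // block
layout = torch.ones(heads, n_blocks, n_blocks, dtype=torch.long)  # keep every block

# build the kernels once per layout
dot_sdd = matmul(layout, block, 'sdd', trans_a=False, trans_b=True)   # Q @ K^T -> block-sparse scores
dot_dsd = matmul(layout, block, 'dsd', trans_a=False, trans_b=False)  # probs @ V -> dense output
bs_softmax = softmax(layout, block)

q, k, v = (torch.randn(batch, heads, seq, dim_head, device='cuda', dtype=torch.float16)
           for _ in range(3))

scores = dot_sdd(q, k)                               # block-sparse scores, one tile per kept block
probs = bs_softmax(scores, scale=dim_head ** -0.5)   # softmax over the kept blocks
out = dot_dsd(probs, v)                              # back to dense (batch, heads, seq, dim_head)
```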
I actually tried that! On my 4·2048·1536 float16 batch:

Dense attention built on top of blocksparse matmul and softmax:
196.8ms + 264.1ms = 460.9ms (forward and backward, no checkpointing, 12 layers), 2.35 GB activations per example

Torch implementation `functional.multi_head_attention_forward`:
157.6ms + 170.8ms = 328.5ms, 4.29 GB activations per example
So it starts to look kind of strange: for dense layers it's better to use the Torch version, for sparse layers this Triton version. But then there's memory to consider. Too complex 😵
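For reference, a sketch of how the "activations per example" figures might be measured: run a forward pass with autograd enabled (no checkpointing) and read the peak CUDA memory. The 12-layer `TransformerEncoderLayer` stack here is only a stand-in for the attention implementations actually being compared.

```python
import torch

layers = [torch.nn.TransformerEncoderLayer(d_model=1536, nhead=12).half().cuda()
          for _ in range(12)]
model = torch.nn.Sequential(*layers)

x = torch.randn(2048, 4, 1536, device='cuda', dtype=torch.float16)  # (seq, batch, dim)

torch.cuda.reset_peak_memory_stats()
base = torch.cuda.memory_allocated()       # weights already resident
out = model(x)                             # activations kept alive for the backward pass
torch.cuda.synchronize()

peak_gb = (torch.cuda.max_memory_allocated() - base) / 2 ** 30
print(f'~{peak_gb / 4:.2f} GB activations per example')
```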