
Flash Attention 3 --> Triton

Open jenkspt opened this issue 1 year ago • 1 comment

Flash attention 3 makes use of new features of the Hopper architecture.

  • (async) WGMMA
  • TMA
  • overlapping softmax with GEMM

Are these all features that the Triton compiler can currently (or will in the future) exploit automatically? And could the fused attention implementation from https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html take advantage of them without changes?
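For context on the "overlapping softmax with GEMM" item: what gets overlapped is the online-softmax recurrence from FlashAttention, which rescales a running max and normalizer so the softmax can be computed block-by-block alongside the matmuls. A minimal NumPy sketch of that recurrence (function and variable names here are illustrative, not from the Triton tutorial):

```python
import numpy as np

def attention_online_softmax(q, k, v, block=2):
    """Attention output for one query row, processing K/V in blocks
    with the online-softmax recurrence (running max m, normalizer l).
    Illustrative sketch only; real kernels vectorize over query tiles."""
    d = q.shape[-1]
    m = -np.inf            # running row max
    l = 0.0                # running softmax denominator
    acc = np.zeros(d)      # unnormalized output accumulator
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = kb @ q / np.sqrt(d)       # scores for this block (the GEMM part)
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)     # rescale previously accumulated stats
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ vb    # second GEMM, fused with softmax math
        m = m_new
    return acc / l

def attention_ref(q, k, v):
    """Reference: full (non-blocked) softmax attention for one query row."""
    s = k @ q / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v
```

Because the per-block softmax math is independent of the next block's score GEMM, hardware with async matrix units (WGMMA) can run the two concurrently, which is what FA3 exploits.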

jenkspt avatar Jul 11 '24 22:07 jenkspt

Triton OSS currently uses WGMMA (async). TMA is still experimental, but work is ongoing to improve the descriptors. For computation overlapping, I am currently investigating whether we can modify the existing SWP (software pipelining) pass to allow specifying stages/clusters for computation ops.
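To make the SWP idea concrete: software pipelining hoists the loads for later iterations ahead of the compute that consumes them, so load latency overlaps useful work. A toy schedule generator (names are illustrative; this only models issue order, not Triton's actual pass) shows the resulting interleaving:

```python
def swp_schedule(n, stages=2):
    """Issue order for an n-iteration loop pipelined with `stages` stages:
    loads run `stages - 1` iterations ahead of the compute that uses them.
    Returns a list of ("load", i) / ("compute", i) ops in issue order."""
    ops = []
    for i in range(n + stages - 1):
        if i < n:
            ops.append(("load", i))       # prologue fills the pipeline
        j = i - (stages - 1)
        if 0 <= j < n:
            ops.append(("compute", j))    # epilogue drains it
    return ops

# With 2 stages, the load for iteration i+1 is issued before compute(i):
# swp_schedule(3) -> load 0, load 1, compute 0, load 2, compute 1, compute 2
```

Allowing users to assign compute ops (not just loads) to stages/clusters would extend this same scheduling to overlap, e.g., softmax with WGMMA.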

manman-ren avatar Jul 15 '24 19:07 manman-ren

@manman-ren @jenkspt WGMMA is now supported in Triton IR.

We have implemented optimizations (warp specialization, computation pipelining) in this branch https://github.com/facebookexperimental/triton/blob/mren/ws-comp-pipeline/README.flash

Also see our latest blog post: https://pytorch.org/blog/warp-specialization/

manman-ren avatar Feb 10 '25 17:02 manman-ren