Flash Attention 3 --> Triton
Flash Attention 3 makes use of new features of the Hopper architecture:
- asynchronous WGMMA instructions
- TMA (Tensor Memory Accelerator)
- overlapping softmax with the matmuls
Are these all things that the Triton compiler can currently optimize automatically (or will be able to in the future)? And could the fused-attention implementation from https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html make use of them without changes?
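For context on the third item: what gets overlapped with the matmuls is the online-softmax update, in which each tile of scores rescales a running max, denominator, and output accumulator so the full score row never has to be materialized. A minimal NumPy sketch of that update (function names and the single-query simplification are mine, not from the tutorial kernel, which does this per tile on the GPU):

```python
import numpy as np

def flash_attention_blocked(q, k, v, block=4):
    """Attention for one query via online softmax over K/V blocks.
    q: (d,), k and v: (n, d). Illustration only."""
    n, d = k.shape
    m = -np.inf          # running max of the scores seen so far
    l = 0.0              # running softmax denominator
    acc = np.zeros(d)    # running (unnormalized) output accumulator
    for start in range(0, n, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = kb @ q / np.sqrt(d)       # this block's attention scores
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)     # rescale old stats to the new max
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ vb
        m = m_new
    return acc / l

def naive_attention(q, k, v):
    """Reference: materialize the whole score row, then softmax it."""
    s = k @ q / np.sqrt(k.shape[1])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ v
```

Because each block's exponentials and rescaling are independent of the next block's score matmul, a Hopper kernel can issue the async WGMMA for block i+1 and run the softmax math for block i concurrently, which is the overlap FA3 exploits.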
Triton OSS currently uses WGMMA (async). TMA support is still experimental, but work is ongoing to improve the descriptors. For computation overlapping, I am looking into whether we can modify the existing software pipeliner (SWP) to allow specifying stages/clusters for computation ops.
@manman-ren @jenkspt WGMMA is now supported in Triton IR.
We have implemented these optimizations (warp specialization, computation pipelining) in this branch: https://github.com/facebookexperimental/triton/blob/mren/ws-comp-pipeline/README.flash
Also see our latest blog post on warp specialization: https://pytorch.org/blog/warp-specialization/