# Ring Flash Attention
This repo implements RingAttention with FlashAttention. Currently, it provides:
- `ring_flash_attn_func`: ring attention version of `flash_attn_func`
- `ring_flash_attn_varlen_func`: ring attention version of `flash_attn_varlen_func`
- `zigzag_ring_flash_attn_func`: an optimized version of `ring_flash_attn_func`, see issue #2
- `zigzag_ring_flash_attn_varlen_func`: an optimized version of `ring_flash_attn_varlen_func`
- `stripe_flash_attn_func`: stripe attention version of `ring_flash_attn_func`; the block size is set to 1 to use the flash_attn api.
Note that

- all functions have `*_func`, `*_kvpacked_func`, and `*_qkvpacked_func` variants implemented.
- the varlen versions only support passing one `cu_seqlens`.
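
As a quick illustration, here is a minimal usage sketch. It is not taken from the repo's docs: it assumes the package exposes `ring_flash_attn_func` under the `ring_flash_attn` module, that the function mirrors `flash_attn_func`'s `(q, k, v, ..., causal=...)` interface, and that each rank holds its local shard of the sequence within the default process group.

```python
# Hypothetical usage sketch (shapes, flags, and module name are assumptions).
# Launch with e.g.: torchrun --nproc_per_node 8 this_script.py
import torch
import torch.distributed as dist
from ring_flash_attn import ring_flash_attn_func

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

batch, local_seqlen, nheads, headdim = 1, 1024, 8, 64  # per-rank shard of the sequence
q, k, v = (
    torch.randn(batch, local_seqlen, nheads, headdim,
                device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)

# All ranks call this collectively; k/v blocks travel around the ring so each
# rank attends to the full (global) sequence while holding only its own shard.
out = ring_flash_attn_func(q, k, v, causal=True)
out.sum().backward()
```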
The main idea is to use the `softmax_lse` output from the flash attention kernels, which makes it possible to rescale and accumulate each block's partial output into the final result.
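
For intuition, here is a sketch of that log-sum-exp based merge. The helper name and shapes are illustrative, not the repo's exact code:

```python
import torch

def merge_block_output(out, lse, block_out, block_lse):
    """Fold one block's partial attention output into the running result.

    Illustrative shapes: out/block_out are [batch, seqlen, nheads, headdim],
    lse/block_lse are [batch, seqlen, nheads, 1], all kept in fp32.
    """
    # New normalizer: log(exp(lse) + exp(block_lse)), computed stably.
    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))
    # Rescale both partial outputs to the new normalizer and add them.
    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    return out, new_lse
```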
The current performance, measured with `benchmark/benchmark_qkvpacked_func.py`, is shown below (percentages are relative to the theoretic flash_attn throughput):

| | GPU | theoretic flash_attn | ring_attn | zigzag_ring | stripe_attn |
|---|---|---|---|---|---|
| fwd only (iter/sec) | 8xH800 | 2418.4 / 8 = 302.3 | 208.0 (68.8%) | 283.0 (93.6%) | 259.6 (85.9%) |
| fwd + bwd (iter/sec) | 8xH800 | 705.2 / 8 = 88.2 | 54.3 (61.5%) | 75.7 (85.9%) | 76.9 (87.2%) |
| fwd only (iter/sec) | 8xA100 | 1545.9 / 8 = 193.2 | 124.4 (64.3%) | 179.0 (92.7%) | 163.9 (84.8%) |
| fwd + bwd (iter/sec) | 8xA100 | 470.6 / 8 = 58.8 | 33.3 (56.6%) | 49.5 (84.1%) | 45.9 (78.1%) |
Note that

- when running the benchmark with 8 GPUs, the flash attn code runs with 1/8 of the computation of ring attention.
- NVLink between GPUs is required for high performance.
- the varlen versions are slow at the moment; please use the non-varlen versions if possible.
## Limits
There are some arithmetic errors with the current implementation. The reason is probably that flash attention returns bf16 values for each block, so the block results cannot be accumulated into the original fp32 values without rounding error.

Also, because extra fp32 buffers need to be saved during computation, the memory usage is higher than the theoretical limit.
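
To illustrate the precision point with a synthetic example (not the repo's code): even with an fp32 accumulator, each block's contribution has already been rounded to bf16 by the kernel, so the result drifts from an all-fp32 computation.

```python
import torch

torch.manual_seed(0)
blocks = [torch.randn(4096, dtype=torch.float32) for _ in range(8)]

# All-fp32 accumulation (what an ideal single-kernel result would look like).
exact = torch.stack(blocks).sum(dim=0)

# Accumulate in fp32, but round each block to bf16 first, mimicking bf16
# per-block kernel outputs.
approx = sum(b.to(torch.bfloat16).to(torch.float32) for b in blocks)

print((exact - approx).abs().max())  # small but non-zero rounding error
```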
## TODOs
- [x] Implement `ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_qkvpacked_func` (see issue #2)
- [x] Implement `stripe_flash_attn_qkvpacked_func`
- [x] Implement `zigzag_ring_flash_attn_varlen_qkvpacked_func`
- [x] Implement `*_kvpacked_func` and `*_func` variants for all APIs
- [ ] Optimize `*_varlen_func`
- [ ] Try to upstream to flash attention.
## Test

```bash
torchrun --nproc_per_node 8 test/test_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_func.py
torchrun --nproc_per_node 8 test/test_zigzag_ring_flash_attn_varlen_func.py
torchrun --nproc_per_node 8 test/test_stripe_flash_attn_func.py
```
## Benchmark

```bash
torchrun --nproc_per_node 8 benchmark/benchmark_qkvpacked_func.py
torchrun --nproc_per_node 8 benchmark/benchmark_varlen_qkvpacked_func.py
```
## Known Limits

- dropout is not supported at the moment, because it is hard to save all the rng_states.
- window_size is not supported, because it would be really tricky to implement a varlen version with window_size.