Support for arbitrary sequence length in FMHA
We experimented with porting the existing FMHA implementation to efficiently support arbitrary sequence lengths for head_dim=64/128, and we seem to have found a working solution: compute blocks one at a time and accumulate the results.
Here's our approach and benchmarks: https://gist.github.com/krunt/72197074816dfe4035fcd9413e4afb22
Here's our implementation (patch to apex): https://github.com/krunt/apex/compare/727a6452c9b781930acee5e24e09efe9360b4890...3655f21606256efdcae092a0b4fd0ac1151d2de8
You can also find the implementation for head dim 128 in the above gist.
Here's how to run/benchmark it: https://github.com/krunt/apex/blob/arbitlen_fmha_headdim_64/run_fmha.py
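To make the block-wise accumulation idea concrete, here is a minimal PyTorch sketch (not the actual apex CUDA kernel; the function name blockwise_attention and the shapes are just for illustration). Keys/values are processed one block at a time while a running row max and softmax denominator are rescaled, so the output matches full attention without ever materializing the whole score matrix, which is what removes the fixed limit on sequence length.

```python
# Illustrative sketch only -- the real kernel does this per tile in CUDA.
import math
import torch

def blockwise_attention(q, k, v, block_size=128):
    """q, k, v: [seq_len, head_dim] for a single head (hypothetical helper)."""
    seq_len, head_dim = q.shape
    scale = 1.0 / math.sqrt(head_dim)

    out = torch.zeros_like(q)                                          # running weighted sum of values
    row_max = torch.full((seq_len, 1), -float("inf"), dtype=q.dtype)   # running max per query row
    row_sum = torch.zeros((seq_len, 1), dtype=q.dtype)                 # running softmax denominator

    for start in range(0, seq_len, block_size):
        kb = k[start:start + block_size]        # one key block
        vb = v[start:start + block_size]        # one value block

        scores = (q @ kb.T) * scale             # [seq_len, block]
        block_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, block_max)

        # Rescale previously accumulated statistics to the new max, then add this block.
        correction = torch.exp(row_max - new_max)
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ vb
        row_max = new_max

    return out / row_sum

# Sanity check against naive full attention.
q, k, v = (torch.randn(1000, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / math.sqrt(64), dim=-1) @ v
print(torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-4))
```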
Authors: @krunt, @TimDettmers, @xtinkt. We also thank @yjk21, @jdemouth, and @jaredcasper for helpful discussions.
Almost the same idea was independently proposed in https://arxiv.org/abs/2205.14135; a similar idea with full re-materialization was proposed in https://arxiv.org/abs/2112.05682.
CC @tridao
This is great! Our forward pass implementations are very similar. The backward looks different (https://github.com/HazyResearch/flash-attention) but let's figure that out during our chat!