Support for arbitrary sequence length in FMHA

Open krunt opened this issue 2 years ago • 2 comments

We experimented with porting the existing FMHA implementation to efficiently support arbitrary sequence lengths for head_dim=64/128, and we seem to have found a working solution: compute blocks one at a time and accumulate the results.
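To make the idea concrete, here is a minimal NumPy sketch of block-at-a-time attention with accumulation. It is not the CUDA kernel from the patch. It assumes the accumulation is done with a running row maximum and a running softmax denominator that are rescaled as each key/value block arrives, which is one common way to do this. The names `blockwise_attention`, `reference_attention`, and `block_size` are hypothetical and chosen only for illustration.

```python
import numpy as np


def reference_attention(q, k, v):
    """Plain softmax attention that materializes the full score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return probs @ v


def blockwise_attention(q, k, v, block_size=128):
    """Attention computed one key/value block at a time (illustrative only).

    Assumption: partial results are combined with a running max and a
    running denominator, so the sequence length does not have to be a
    multiple of block_size.
    """
    seq_q, head_dim = q.shape
    scale = 1.0 / np.sqrt(head_dim)

    running_max = np.full((seq_q, 1), -np.inf)  # row-wise max seen so far
    running_sum = np.zeros((seq_q, 1))          # softmax denominator so far
    acc = np.zeros((seq_q, v.shape[-1]))        # unnormalized output so far

    for start in range(0, k.shape[0], block_size):
        k_blk = k[start:start + block_size]     # last block may be shorter
        v_blk = v[start:start + block_size]

        scores = (q @ k_blk.T) * scale
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))

        # Rescale what was accumulated under the old max to the new max.
        correction = np.exp(running_max - new_max)
        probs = np.exp(scores - new_max)

        running_sum = running_sum * correction + probs.sum(axis=-1, keepdims=True)
        acc = acc * correction + probs @ v_blk
        running_max = new_max

    return acc / running_sum
```

Because the final block is simply a shorter slice, a sequence length such as 300 works with `block_size=128` without padding, and the result should match `reference_attention` up to floating-point error.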

Here's our approach and benchmarks: https://gist.github.com/krunt/72197074816dfe4035fcd9413e4afb22

Here's our implementation (patch to apex): https://github.com/krunt/apex/compare/727a6452c9b781930acee5e24e09efe9360b4890...3655f21606256efdcae092a0b4fd0ac1151d2de8

You can also find the implementation for head dim 128 in the above gist.

Here's how to run/benchmark it: https://github.com/krunt/apex/blob/arbitlen_fmha_headdim_64/run_fmha.py

Authors: @krunt, @TimDettmers, @xtinkt

We also thank @yjk21, @jdemouth, and @jaredcasper for helpful discussions.

Almost the same idea was independently proposed in: https://arxiv.org/abs/2205.14135

A similar idea with full re-materialization was proposed in: https://arxiv.org/abs/2112.05682

krunt avatar May 31 '22 10:05 krunt

CC @tridao

justheuristic avatar May 31 '22 11:05 justheuristic

This is great! Our forward pass implementations are very similar. The backward pass looks different (https://github.com/HazyResearch/flash-attention), but let's figure that out during our chat!

tridao avatar May 31 '22 17:05 tridao