Tri Dao
If your application is very sensitive to numerical error, then flash-attn might not be a good fit, mainly because we only support fp16 / bf16 and not fp32.
Then flash-attn should be more accurate than the standard implementation. You want to compare (flash-attn in bf16 - reference impl in fp32) vs (reference impl in bf16 - reference impl in fp32).
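For concreteness, a minimal sketch of that comparison (the shapes and the reference implementation here are illustrative, not the exact test from this issue):

```python
import torch
from flash_attn import flash_attn_func

def ref_attn(q, k, v):
    # Plain (non-causal) attention, computed in the dtype of its inputs.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bthd,bshd->bhts", q, k) * scale
    return torch.einsum("bhts,bshd->bthd", torch.softmax(scores, dim=-1), v)

batch, seqlen, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float32)
k, v = torch.randn_like(q), torch.randn_like(q)

out_ref_fp32 = ref_attn(q, k, v)
out_ref_bf16 = ref_attn(q.bfloat16(), k.bfloat16(), v.bfloat16()).float()
out_flash_bf16 = flash_attn_func(q.bfloat16(), k.bfloat16(), v.bfloat16()).float()

# flash-attn's bf16 error should be comparable to (or smaller than) the
# reference implementation's own bf16 error.
print("flash bf16 vs ref fp32:", (out_flash_bf16 - out_ref_fp32).abs().max().item())
print("ref bf16   vs ref fp32:", (out_ref_bf16 - out_ref_fp32).abs().max().item())
```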
The error seems too high; you can try `flash_attn_func` since it's simpler to call (no need to construct cu_seqlens, which might be error-prone). Try to make the test as simple as possible.
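A sketch of the difference between the two call paths (assuming equal-length sequences; shapes are illustrative):

```python
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

batch, seqlen, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Simple interface: padded (batch, seqlen, nheads, headdim) tensors, no cu_seqlens.
out = flash_attn_func(q, k, v, causal=True)

# Varlen interface: tokens flattened to (total, nheads, headdim) plus cu_seqlens,
# which is the part that is easy to get wrong.
cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")
q_u, k_u, v_u = (t.reshape(batch * seqlen, nheads, headdim) for t in (q, k, v))
out_varlen = flash_attn_varlen_func(
    q_u, k_u, v_u, cu_seqlens, cu_seqlens, seqlen, seqlen, causal=True
).reshape(batch, seqlen, nheads, headdim)
```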
I haven't tried with PyTorch 2.2.2, but I don't see why compiling from source wouldn't work. The wheel may or may not be compatible.
Looks like it's still downloading the wheel? Can you try `python3 setup.py install`?
We have new wheels (flash-attn 2.5.8) that should work with torch 2.2.2
Most of the reads/writes are coalesced. There are some small writes (e.g. writing the LSE, the softmax log-sum-exp) that are not, but I don't think it matters. Let me know if you profile more and find otherwise.
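If you want to dig further, a minimal profiling sketch (the shapes are illustrative; `torch.profiler` only gives per-kernel timings, so kernel-level memory-access metrics would need something like Nsight Compute):

```python
import torch
from torch.profiler import profile, ProfilerActivity
from flash_attn import flash_attn_func

q = torch.randn(4, 2048, 16, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        flash_attn_func(q, k, v, causal=True)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```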
`pip install flash-attn==1.0.9`
Sure, you can try that.
I don't quite understand your algorithm; can you add some pseudocode?