Tri Dao
Can you profile to get the time for the attention kernel?
Can you profile and post how long the FA2 kernel and the FA3 kernel take? And what are the input shapes?
I see you're using hdim 32. FA2 does have an implementation for hdim 32, but FA3 does not specialize for hdim 32 and instead rounds it up to hdim 64. You can...
It's not hard, but we don't plan to specialize for hdim 32 in FA3, in order to keep compilation time down (it's already taking very long to compile). Most models use hdim 64...
The code stays the same; just change the dispatching. You want to add to tile_size: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/tile_size.h
Then add the new hdim to `generate_kernels` and `setup.py`: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/generate_kernels.py
Then add to the...
These are all tuned empirically. You can copy the hdim64 case (e.g. 128 x 112). The tile size 128 x 128 should also work.
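For what it's worth, here is a minimal sketch of where an hdim-32 branch could slot into the tile-size dispatch. The function name, signature, and every tile size other than the 128 x 112 / 128 x 128 mentioned above are made up for illustration; the real functions in `tile_size.h` take more inputs than just the head dimension, and the new hdim still has to be added to `generate_kernels.py` and `setup.py` as described.

```cpp
// Hypothetical sketch only: the real hopper/tile_size.h uses different
// function names, extra parameters, and different tuned values. This just
// illustrates where an hdim-32 branch would slot into the dispatch.
#include <cstdio>
#include <tuple>

// Returns (kBlockM, kBlockN) for the forward kernel given the head dimension.
constexpr std::tuple<int, int> tile_size_fwd(int headdim) {
    // New branch: start from the hdim-64 tuning (128 x 112); 128 x 128 is
    // another reasonable starting point. Tune empirically from there.
    if (headdim <= 32) { return std::make_tuple(128, 112); }
    if (headdim <= 64) { return std::make_tuple(128, 112); }  // hdim-64 case from above
    // Placeholder values below, not taken from the repo.
    if (headdim <= 128) { return std::make_tuple(128, 80); }
    return std::make_tuple(128, 64);
}

int main() {
    auto [kBlockM, kBlockN] = tile_size_fwd(32);
    std::printf("hdim 32 -> tile size %d x %d\n", kBlockM, kBlockN);
    return 0;
}
```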
That sounds right. You can tune `tile_size.h`, but I'm not sure how much that would improve things. Hdim 32 is just not very hardware-friendly.
Soon, 3-4 weeks
It's still a work in progress and not complete yet