Tri Dao

Results: 472 comments by Tri Dao

I see. I've just pushed a commit to avoid the name collision. At some point FA3 will replace FA2, but for now they should be able to coexist.

I should clarify in the docs that `seqused` is only for cases where you know what you're doing (e.g., the cache length during decoding, or splitting the sequence for context parallelism)....
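To make the decoding case concrete, here is a minimal NumPy reference of what a per-sequence `seqused_k` means. It is not the FlashAttention API and not how the kernel computes anything. The assumption is that the KV cache is allocated to a fixed maximum length, and batch element `b` only attends to its first `seqused_k[b]` keys. The function name and the `(batch, seqlen, nheads, headdim)` layout are illustrative assumptions.

```python
import numpy as np


def attention_with_seqused_k(q, k, v, seqused_k):
    """Dense reference attention where batch element b only attends to
    its first seqused_k[b] keys; the rest of the cache is ignored.

    Assumed shapes (illustrative layout):
      q:          (batch, seqlen_q, nheads, headdim)
      k, v:       (batch, seqlen_k_max, nheads, headdim)
      seqused_k:  (batch,) ints, each in [1, seqlen_k_max]
    """
    batch, seqlen_q, nheads, d = q.shape
    scale = 1.0 / np.sqrt(d)  # 1/sqrt(headdim) softmax scale
    out = np.zeros_like(q)
    for b in range(batch):
        n = int(seqused_k[b])
        kb, vb = k[b, :n], v[b, :n]  # only the used part of the cache
        # scores: (nheads, seqlen_q, n)
        s = np.einsum("qhd,khd->hqk", q[b], kb) * scale
        s = s - s.max(axis=-1, keepdims=True)  # numerical stability
        p = np.exp(s)
        p /= p.sum(axis=-1, keepdims=True)
        out[b] = np.einsum("hqk,khd->qhd", p, vb)
    return out
```

Anything stored past `seqused_k[b]` in the cache (stale or never-written entries) does not affect the result, which is exactly why the caller has to get these lengths right.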

Do you have an example of what `seqused_q` and `seqused_k` look like? And what are the shapes of q, k, v? I might add an option to zero out the...

My original hypothesis is that the memory locations in the output & gradients corresponding to unused tokens are uninitialized. You might get lucky and most of the time these locations...
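If the kernel leaves those slots uninitialized, one caller-side safeguard is to overwrite every position at or beyond `seqused` before it reaches a loss or a reduction. This is sketched in NumPy as an assumption; it is not an existing FlashAttention option, and the helper names are hypothetical.

```python
import numpy as np


def used_token_mask(seqused, seqlen_max):
    """Boolean mask (batch, seqlen_max): True where position < seqused[b]."""
    return np.arange(seqlen_max)[None, :] < np.asarray(seqused)[:, None]


def zero_unused(x, seqused):
    """Return a copy of x (batch, seqlen, ...) with positions >= seqused[b]
    along axis 1 set to 0, whatever garbage they held before."""
    mask = used_token_mask(seqused, x.shape[1])
    # Add trailing singleton axes so the mask broadcasts over heads/headdim.
    mask = mask.reshape(mask.shape + (1,) * (x.ndim - 2))
    return np.where(mask, x, 0.0)
```

`np.where` is used instead of multiplying by a 0/1 mask because uninitialized memory can hold NaN or Inf, and `NaN * 0` is still NaN. The same helper applies to the gradients of unused tokens.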

Backprop on `softmax_lse` is not supported. Feel free to work on it if you need it: you just have to work out the gradient and then implement it. I suspect...

It depends on what the gradient looks like. What's the gradient of `softmax_lse`?

Well, you need to work out how to compute the gradient mathematically (e.g., see Appendix B.2 of the FlashAttention paper) before implementing it.
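Here is a sketch of that derivation. It follows the backward-pass notation of the FlashAttention paper, with $\tau$ the softmax scale, $L_i$ the row-wise log-sum-exp that `softmax_lse` stores, and $dL$ its upstream gradient. It is a worked derivation under those assumptions, not the library's implementation:

$$
\begin{aligned}
S &= \tau Q K^\top, \qquad P = \operatorname{softmax}(S), \qquad O = P V, \qquad L_i = \log \sum_j e^{S_{ij}}, \\
\frac{\partial L_i}{\partial S_{ij}} &= \frac{e^{S_{ij}}}{\sum_{j'} e^{S_{ij'}}} = P_{ij}, \\
dP &= dO\, V^\top, \qquad D_i = dO_i^\top O_i, \\
dS_{ij} &= P_{ij}\,(dP_{ij} - D_i) + dL_i\, P_{ij} = P_{ij}\,\bigl(dP_{ij} - (D_i - dL_i)\bigr), \\
dQ &= \tau\, dS\, K, \qquad dK = \tau\, dS^\top Q, \qquad dV = P^\top dO.
\end{aligned}
$$

So under this derivation the LSE gradient only shifts the row-wise term $D_i$ by $-dL_i$, and $dV$ is unchanged.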

The backward-pass code is here: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h. You can follow `acc_dk` and `acc_dq` to see how dK and dQ are currently computed.
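Before touching the kernel, it can help to check the math against a dense reference. The NumPy sketch below is single-head, with no masking or dropout, and the function names are hypothetical. It adds an upstream gradient `d_lse` on the log-sum-exp, using $\partial L_i / \partial S_{ij} = P_{ij}$, so the extra term amounts to shifting $D_i$ by `-d_lse[i]`. The `dq`/`dk` lines are the dense analogues of what `acc_dq`/`acc_dk` accumulate tile by tile. That correspondence is an assumption, not a reading of the kernel.

```python
import numpy as np


def attn_fwd(q, k, v, scale):
    """Single-head dense reference forward.

    q: (n_q, d), k and v: (n_k, d).
    Returns the output O (n_q, d) and the per-row log-sum-exp L (n_q,),
    i.e. the quantity softmax_lse holds.
    """
    s = scale * (q @ k.T)
    m = s.max(axis=1, keepdims=True)  # subtract row max for stability
    lse = (m + np.log(np.exp(s - m).sum(axis=1, keepdims=True)))[:, 0]
    p = np.exp(s - lse[:, None])  # softmax(S) recovered from the LSE
    return p @ v, lse


def attn_bwd(q, k, v, scale, d_out, d_lse):
    """Dense reference backward with upstream gradients on both O and L.

    d_out: (n_q, d) gradient w.r.t. O; d_lse: (n_q,) gradient w.r.t. L.
    Returns (dq, dk, dv).
    """
    out, lse = attn_fwd(q, k, v, scale)
    p = np.exp(scale * (q @ k.T) - lse[:, None])
    dv = p.T @ d_out
    dp = d_out @ v.T
    D = (d_out * out).sum(axis=1)  # D_i = dO_i . O_i
    # Standard term P * (dP - D) plus the LSE term d_lse_i * P_ij,
    # written as a shift of D by -d_lse.
    ds = p * (dp - (D - d_lse)[:, None])
    dq = scale * (ds @ k)    # dense analogue of acc_dq (assumed correspondence)
    dk = scale * (ds.T @ q)  # dense analogue of acc_dk (assumed correspondence)
    return dq, dk, dv
```

Setting `d_lse` to zeros recovers the usual backward. A finite-difference check on `sum(O * d_out) + sum(L * d_lse)` is a quick way to confirm the gradients before porting the change to a tiled kernel.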

Yup, exp is one of the bottlenecks. We talked about that a bit in the FA3 paper.

exp uses the MUFU (multi-function unit), which has quite low throughput (e.g., 16 ops per clock cycle per SM, 4-8x lower than floating-point add/mul). https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#arithmetic-instructions