Tri Dao


Can you add a short script to reproduce the error?

I don't know much about the ROCm version; you might have better luck asking on their repo.

Thanks for the report, I can reproduce it. Investigating now. Might be because of the way torch (in C++) handles dtypes.

Hmm, compiling from scratch seems to work fine, so something is wrong with the wheel we built.

I'm guessing this is because 24.03 uses CUDA 12.4 and the wheels built with nvcc 12.2 are somehow not compatible.
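
For context, a quick way to confirm that kind of mismatch is to print the CUDA version torch was built against next to the installed flash-attn wheel. A minimal sketch (the specific container/toolkit versions above are from the report, not checked here):

```python
# Sanity check: which CUDA version is torch built with, and which
# flash-attn wheel is installed? A mismatch with the toolkit the wheel
# was compiled with (nvcc 12.2 here) can cause import/runtime errors.
import torch
import flash_attn

print("torch:", torch.__version__)
print("torch built with CUDA:", torch.version.cuda)
print("flash-attn:", flash_attn.__version__)
```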

What about just adding a key that's all zero before calling attention?
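
A minimal sketch of that idea, assuming the standard `flash_attn_func` interface and made-up shapes (not from the thread): prepend one all-zero key/value pair so every query has at least one key to attend to.

```python
import torch
from flash_attn import flash_attn_func

# Hypothetical shapes: (batch, seqlen, nheads, headdim).
q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")

# Prepend an all-zero key (and a matching zero value). The zero key has
# dot product 0 with every query, so it only adds a constant term to the
# softmax denominator, and its zero value contributes nothing to the output.
k_padded = torch.cat([torch.zeros_like(k[:, :1]), k], dim=1)
v_padded = torch.cat([torch.zeros_like(v[:, :1]), v], dim=1)

out = flash_attn_func(q, k_padded, v_padded)
```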

> I'd like to apply this to a large-scale training, so keeping optimal performance is important to me. I'm curious, what kind of improvement do you see at smaller scale?

[This](https://github.com/Dao-AILab/flash-attention/blob/02ac572f3ffc4f402e4183aaa6824b45859d3ed3/csrc/flash_attn/src/flash_fwd_kernel.h#L509) multiplies the output by the inverse of the denominator, so you can add something to the denominator there. I think the output was already multiplied by exp(-max) before this...
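
To make the math concrete, here is a plain PyTorch (non-fused) sketch of what adding a term to the denominator at that point corresponds to; the function and the `extra` argument are made up for illustration, and the sums are kept in the same exp(-max)-shifted scale the kernel uses:

```python
import torch

def attn_with_extra_denom(q, k, v, extra=1.0, scale=None):
    # Reference attention with an extra constant added to the softmax
    # denominator, mirroring the point where the kernel divides by the row sum.
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    s = (q @ k.transpose(-2, -1)) * scale   # raw attention scores
    m = s.amax(dim=-1, keepdim=True)        # row max (the kernel's running max)
    p = (s - m).exp()                       # exp(scores - max)
    o = p @ v                               # un-normalized output, already scaled by exp(-max)
    # Adding `extra` to the un-shifted denominator sum_i exp(s_i) means adding
    # extra * exp(-max) in the shifted scale before dividing.
    denom = p.sum(dim=-1, keepdim=True) + extra * (-m).exp()
    return o / denom
```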

I'm not sure how this works, do you have a suggestion?

Not sure I understand the question, but the function docstring should tell you the shapes of the tensors and whether they need to be contiguous:
```
k_cache: (batch_size_cache, seqlen_cache, nheads_k,...
```
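
For reference, a minimal usage sketch, assuming the question is about `flash_attn_with_kvcache` (the shapes follow the quoted docstring; the sizes below are placeholders):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
cache_len, new_len = 1024, 1

q = torch.randn(batch, new_len, nheads, headdim, dtype=torch.float16, device="cuda")
# Pre-allocated KV cache: (batch_size_cache, seqlen_cache, nheads_k, headdim).
k_cache = torch.zeros(batch, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros(batch, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
# New key/value tokens to append to the cache at position cache_seqlens.
k_new = torch.randn(batch, new_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_new = torch.randn(batch, new_len, nheads, headdim, dtype=torch.float16, device="cuda")
# Current number of valid tokens in the cache for each batch element.
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, causal=True,
)
```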