Tri Dao
The improvement in the backward pass is a combination of factors:
- Not using split-k (sketched below), so we reduce the amount of shared memory needed and the shared memory reads/writes.
- Better work...
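
For intuition only, here is a plain PyTorch sketch of what split-k means, not the kernel code itself: the K dimension is split into chunks whose partial products must be stored and summed afterwards, which is the extra memory traffic a non-split-k kernel avoids. Shapes and the number of splits are made up.

```python
import torch

# Arbitrary toy shapes, for illustration only.
M, N, K, num_splits = 128, 128, 1024, 4
split = K // num_splits
a = torch.randn(M, K)
b = torch.randn(K, N)

# Split-k style: each split computes a partial product over a slice of the
# K dimension; the partials then have to be written out and summed.
partials = [a[:, i * split:(i + 1) * split] @ b[i * split:(i + 1) * split, :]
            for i in range(num_splits)]
out_split_k = torch.stack(partials).sum(dim=0)

# Single pass over the full K dimension: no partial results to combine.
out_single = a @ b

assert torch.allclose(out_split_k, out_single, atol=1e-3)
```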
Thanks for the bug report. I can reproduce the error now.
Yes that's right.
Yes, q@k^T is in fp32, the softmax is done in fp32, and the result is then converted to bf16 for the gemm with V.
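
As a plain PyTorch reference for that dtype flow only (not the fused kernel; the shapes here are made up, and the fp32 accumulation of the bf16 matmul is approximated by upcasting the inputs):

```python
import torch

# Toy (batch, heads, seqlen, head_dim) tensors in bf16.
q = torch.randn(1, 2, 128, 64, dtype=torch.bfloat16)
k = torch.randn(1, 2, 128, 64, dtype=torch.bfloat16)
v = torch.randn(1, 2, 128, 64, dtype=torch.bfloat16)
softmax_scale = q.shape[-1] ** -0.5

# q @ k^T in fp32, softmax in fp32 ...
scores = (q.float() @ k.float().transpose(-2, -1)) * softmax_scale
probs = torch.softmax(scores, dim=-1)

# ... then cast back to bf16 for the gemm with V.
out = probs.to(torch.bfloat16) @ v
```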
Can you try with gcc 10?
What's your `nvcc` version?
Those all look reasonable, I've no idea why it fails. We recommend the [PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) container from Nvidia, which has all the required tools to install FlashAttention.
Which CUDA version are you using?
> [@tridao](https://github.com/tridao) torch: 2.4.0+cu124, nvcc: V12.4.131

You should try the latest version. It works fine for me. Btw your sequence lengths aren't right since x has `seqlens*2-1` but you `cu_seqlens_kv`...
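
For what it's worth, a minimal sketch of how consistent cumulative sequence lengths are typically built for the varlen interface; the lengths here are made up, the point is only that the prefix sums and the packed tensors have to agree:

```python
import torch
import torch.nn.functional as F

# Made-up per-sequence lengths for the packed batch (on GPU in practice).
seqlens_q = torch.tensor([17, 33, 64], dtype=torch.int32)
seqlens_k = torch.tensor([17, 33, 64], dtype=torch.int32)

# cu_seqlens are the prefix sums starting at 0; the last entry must equal
# the total number of tokens packed into q (resp. k/v).
cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))

max_seqlen_q = int(seqlens_q.max())
max_seqlen_k = int(seqlens_k.max())

# q is then packed as (cu_seqlens_q[-1], nheads, head_dim) and k/v as
# (cu_seqlens_k[-1], nheads_k, head_dim); a mismatch between the packed
# lengths and the cu_seqlens leads to incorrect indexing inside the kernel.
```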
> varlen hang happening to me too on `flashattn-hopper==3.0.0b1` (the wheel distributed alongside `flash_attn==2.7.2.post1`).
>
> using CUDA 12.8, H100, pytorch 2.6.0, driver 535.216.01.
>
> if seqused is None,...