Tri Dao

482 comments of Tri Dao

The improvement in the backward pass is a combination of factors:
- Not using split-k, so we reduce the amount of shared memory needed and the shared memory reads/writes.
- Better work...
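For intuition only, here is a minimal PyTorch sketch of what split-k means for a plain matmul (the name `split_k_matmul` is made up for illustration and is not flash-attn kernel code): splitting the reduction dimension produces partial results that have to be stored and then summed, which is the extra memory traffic a non-split-k schedule avoids.

```python
import torch

def split_k_matmul(a: torch.Tensor, b: torch.Tensor, splits: int) -> torch.Tensor:
    # a: (M, K), b: (K, N); partition the K (reduction) dimension into `splits` chunks.
    partials = []
    for a_chunk, b_chunk in zip(a.chunk(splits, dim=1), b.chunk(splits, dim=0)):
        partials.append(a_chunk @ b_chunk)   # each partial result must be materialized
    return torch.stack(partials).sum(dim=0)  # extra reduction pass over the partials

a, b = torch.randn(128, 256), torch.randn(256, 64)
assert torch.allclose(split_k_matmul(a, b, splits=4), a @ b, atol=1e-3)
```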

Thanks for the bug report. I can reproduce the error now.

Yes that's right.

Yes, q@k^T is computed in fp32, the softmax is done in fp32, and the result is then converted to bf16 for the gemm with V.
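A minimal PyTorch reference of that precision policy (the fused kernel does this tile by tile; `attention_reference` is just an illustrative name, not the library's API):

```python
import torch

def attention_reference(q, k, v, scale):
    # q, k, v: (batch, heads, seqlen, headdim) in bf16
    s = (q.float() @ k.float().transpose(-2, -1)) * scale   # q @ k^T accumulated in fp32
    p = torch.softmax(s, dim=-1)                            # softmax in fp32
    return p.to(torch.bfloat16) @ v                         # cast to bf16 for the gemm with V

q = torch.randn(1, 2, 16, 64, dtype=torch.bfloat16)
k = torch.randn(1, 2, 16, 64, dtype=torch.bfloat16)
v = torch.randn(1, 2, 16, 64, dtype=torch.bfloat16)
out = attention_reference(q, k, v, scale=64 ** -0.5)
```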

Those all look reasonable; I have no idea why it fails. We recommend the [PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) container from NVIDIA, which has all the required tools to install FlashAttention.

Which cuda version are you using?

> [@tridao](https://github.com/tridao) torch: 2.4.0+cu124, nvcc: V12.4.131

You should try the latest version. It works fine for me. Btw your sequence lengths aren't right, since x has `seqlens*2-1` but your `cu_seqlens_kv`...
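Not from the thread itself, just a hedged sketch of how `cu_seqlens` is normally built so it stays consistent with the packed tensor: the exclusive prefix sums of the per-sequence lengths, starting at 0 and ending at the total packed length. The `flash_attn_varlen_func` call follows the flash-attn 2.x Python interface; the lengths and shapes here are arbitrary, so check the arguments against the installed version.

```python
import torch
from flash_attn import flash_attn_varlen_func  # flash-attn 2.x interface

# Per-sequence lengths of the packed (varlen) batch; values are illustrative.
seqlens = torch.tensor([5, 9, 3], dtype=torch.int32)
total = int(seqlens.sum())

# cu_seqlens: [0, 5, 14, 17]; the last entry must equal the packed total length.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        seqlens.cumsum(0)]).to(torch.int32).cuda()

nheads, headdim = 8, 64
q = torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total, nheads, headdim, dtype=torch.bfloat16, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)
```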

> varlen hang happening to me too on `flashattn-hopper==3.0.0b1` (the wheel distributed alongside `flash_attn==2.7.2.post1`). > > using CUDA 12.8, H100, pytorch 2.6.0, driver 535.216.01. > > if seqused is None,...