Tri Dao
Can you give a short script showing the numerical error?
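A minimal sketch of the kind of script being asked for, comparing `flash_attn_func` against a plain fp32 PyTorch attention reference (the shapes, dtype, and causal setting here are placeholders, not the reporter's configuration):

```python
import torch
from flash_attn import flash_attn_func

# Placeholder configuration -- substitute the failing shapes/dtype.
batch, seqlen, nheads, headdim = 2, 512, 8, 64
q, k, v = (torch.randn(batch, seqlen, nheads, headdim,
                       device="cuda", dtype=torch.float16)
           for _ in range(3))

out = flash_attn_func(q, k, v, causal=True)

# Reference attention computed in fp32 for comparison.
qf, kf, vf = (t.transpose(1, 2).float() for t in (q, k, v))
scores = qf @ kf.transpose(-2, -1) / headdim ** 0.5
mask = torch.triu(torch.ones(seqlen, seqlen, device="cuda",
                             dtype=torch.bool), 1)
scores.masked_fill_(mask, float("-inf"))
ref = (scores.softmax(-1) @ vf).transpose(1, 2)

print("max abs error:", (out.float() - ref).abs().max().item())
```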
Sure, we'll just need someone to contribute :D
The RTX 2080 (Turing) is not supported in the latest version.
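A guard along these lines fails early with a clear message instead of crashing inside the kernel (a sketch; the sm_80 floor reflects the Ampere-and-newer support mentioned here):

```python
import torch

# Turing is compute capability 7.5; recent flash-attn releases target
# Ampere (sm_80) and newer.
major, minor = torch.cuda.get_device_capability()
if (major, minor) < (8, 0):
    raise RuntimeError(
        f"GPU is sm_{major}{minor}; this flash-attn build needs sm_80 or newer"
    )
```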
Thanks so much for your work @skrider. Can you rebase and then I'll merge?
Yep, we have new wheels compiled for pytorch 2.3.0
Wheels are built for torch 2.2.2 and torch 2.3.0. It looks like they're not compatible with 2.2.0. You can try a previous version of flash-attn, or build from source.
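As an illustrative compatibility check before importing flash-attn (the `packaging` dependency and the message text are mine, not part of the library):

```python
import torch
from packaging import version

# Wheels for this flash-attn release were built against torch 2.2.2 and
# 2.3.0; 2.2.0 is reported not to work with them.
torch_ver = version.parse(torch.__version__.split("+")[0])
if torch_ver < version.parse("2.2.2"):
    print(f"torch {torch.__version__} predates the prebuilt wheels; "
          "install an older flash-attn or build from source.")
```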
Unfortunately I haven't had much bandwidth.
Turing cards have less shared memory (64 KB, versus 99 KB or 163 KB on Ampere), so supporting them might require adjusting the block sizes currently used.
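For reference, a small sketch that reports the shared-memory budget for the detected GPU, using the figures above (the table is hand-written; sm_89/Ada is assumed to match sm_86):

```python
import torch

# Per-SM shared memory by architecture, from the figures above and the
# CUDA docs (sm_89 assumed to match sm_86 here).
SMEM_KB = {(7, 5): 64, (8, 0): 163, (8, 6): 99, (8, 9): 99}

cc = torch.cuda.get_device_capability()
smem = SMEM_KB.get(cc)
if smem is None:
    print(f"sm_{cc[0]}{cc[1]}: not in table")
else:
    print(f"sm_{cc[0]}{cc[1]}: ~{smem} KB shared memory per SM")
```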
Make sure `nvcc` is a supported version by running `nvcc -V`.
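If it helps, a sketch that checks the `nvcc -V` output programmatically (the 11.7 floor is an illustrative assumption, not flash-attn's exact minimum):

```python
import re
import subprocess

# Parse the release number out of `nvcc -V` (e.g. "release 12.1, V12.1.105").
out = subprocess.run(["nvcc", "-V"], capture_output=True, text=True).stdout
m = re.search(r"release (\d+)\.(\d+)", out)
if m is None:
    raise RuntimeError("could not parse `nvcc -V`; is the CUDA toolkit on PATH?")
major, minor = map(int, m.groups())
# 11.7 here is an illustrative floor, not the project's exact requirement.
if (major, minor) < (11, 7):
    print(f"nvcc {major}.{minor} is likely too old to build flash-attn")
```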
Can you save the tensors being passed to flash_attn_cuda.varlen_bwd and send them to me? Otherwise it would be very hard to debug. And can you print out the value of...
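A sketch of how one might capture those tensors with `torch.save` (the helper name, argument placeholder, and file path are made up for illustration):

```python
import torch

def dump_args(*args, path="varlen_bwd_inputs.pt"):
    # Hypothetical helper: move every tensor argument to CPU and save the
    # whole argument list so the failing call can be replayed offline.
    torch.save(
        [a.detach().cpu() if torch.is_tensor(a) else a for a in args],
        path,
    )

# At the call site, pass through exactly what goes into
# flash_attn_cuda.varlen_bwd, then attach the .pt file to the issue:
#   dump_args(*varlen_bwd_args)
```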