Tri Dao
I've been trying to make this work, but I'm not experienced with torch.compile. If you figure something out, please let me know.
Thanks for the investigation. So right now it sounds like it's hard to do in-place ops with torch.compile.
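For context, a minimal sketch (not from this thread) of the pattern being discussed: mutating a function input in place, as a kernel updating a KV cache would, is harder for torch.compile to trace than an out-of-place rewrite. The names `kv_cache`, `new_kv`, and both functions are hypothetical, assuming a recent PyTorch.

```python
import torch

# Hypothetical sketch: in-place mutation of a function *input* is the kind
# of op torch.compile has trouble with, since the compiler must reason
# about aliasing and mutation. An out-of-place rewrite is the usual workaround.

def inplace_update(kv_cache: torch.Tensor, new_kv: torch.Tensor) -> torch.Tensor:
    kv_cache.add_(new_kv)      # mutates an input tensor in place
    return kv_cache

def functional_update(kv_cache: torch.Tensor, new_kv: torch.Tensor) -> torch.Tensor:
    return kv_cache + new_kv   # out-of-place: straightforward to trace

# Wrapping is cheap; actual compilation happens on the first call. The
# functional version traces cleanly; the in-place version may be cloned
# or cause a fallback, depending on the PyTorch version.
compiled_update = torch.compile(functional_update)
```

Whether the in-place version compiles without graph breaks depends on the PyTorch version, which is presumably what the investigation above was probing.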
Sure, can you point me to how they do it?
There's a plan, but it'll take a while.
No, we don't commit to a public timeline. It really depends on how much time folks are contributing.
We're working on Blackwell
It's coming along. Meanwhile, you can use either cuDNN or the CUTLASS implementation: https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
We're building on the CuTe DSL example here: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/fmha.py If you'd like to help, you can start by porting the backward pass from C++ to CuTe DSL: https://github.com/NVIDIA/cutlass/tree/main/examples/77_blackwell_fmha
Yes, we plan to support aarch64 (because of GB200). Currently CuTe DSL doesn't have an aarch64 wheel yet (only x86_64), but that will be fixed soon.