Add support for memory efficient attention for AMD/ROCm
🚀 The feature, motivation and pitch
Enable the Flash Attention and memory-efficient attention backends of scaled_dot_product_attention (SDPA) on AMD/ROCm GPUs.
At present, selecting these backends with the latest nightlies (torch==2.4.0.dev20240413+rocm6.0, pytorch-triton-rocm 3.0.0+0a22a91d04) produces the warning below:
```
/site-packages/diffusers/models/attention_processor.py:1117: UserWarning: 1Torch was not compiled with memory efficient attention. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:505.)
```
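For reference, a minimal repro sketch (shapes are arbitrary; assumes a ROCm nightly build, where the `cuda` device type maps to HIP). Restricting SDPA to the memory-efficient backend via the `torch.nn.attention.sdpa_kernel` context manager (available since torch 2.3) should surface the same warning and then fail, since no allowed backend can run:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Arbitrary attention shapes: (batch, heads, seq_len, head_dim).
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Restrict SDPA to the memory-efficient backend only; on a ROCm build
# without these kernels this warns and then raises, because no allowed
# backend is available.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```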
Alternatives
There is no real alternative: users cannot use the native PyTorch SDPA API with memory-efficient attention on ROCm and are left with the math fallback, as sketched below.
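A minimal sketch of that stopgap, pinning SDPA to the math backend, which ships in every build but is slower and materializes the full attention matrix in memory:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Arbitrary attention shapes, as in the repro above.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Force the math backend so the call succeeds even without
# flash/memory-efficient kernels compiled in.
with sdpa_kernel(SDPBackend.MATH):
    out = F.scaled_dot_product_attention(q, k, v)
```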
Additional context
No response
Hi,
I'm not sure what the current status is, but it looks like AMD has been working on it: https://github.com/pytorch/pytorch/pull/114309