composable_kernel
WMMA / RDNA3+ kernels for backwards fused attention?
Problem Description
Composable Kernel currently only supports fused attention (FA2) on RDNA3(+) architectures in the forward direction. This greatly increases the VRAM requirements for training LoRAs on LLMs using HuggingFace's Transformers and PEFT libraries: training jobs that succeed on an NVIDIA GeForce RTX 4080 with just 16GB of VRAM fail on a Radeon RX 7900 XT with 20GB.
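To make this concrete, here is a minimal sketch of this kind of training job (the model id, LoRA settings and sequence length below are just placeholders, not taken from my actual runs); the backward fused-attention kernel is exercised as soon as loss.backward() runs:

```python
# Hedged sketch of a PEFT LoRA fine-tune that requests FA2 kernels.
# Placeholder model id and hyperparameters; not a report of the exact failing job.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",               # placeholder model id
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",   # request the FA2 kernels
).to("cuda")
model = get_peft_model(
    model,
    LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
)

input_ids = torch.randint(0, model.config.vocab_size, (1, 512), device="cuda")
loss = model(input_ids=input_ids, labels=input_ids).loss
loss.backward()   # this step is where a backward fused-attention kernel is needed
```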
Based on https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal and https://github.com/Repeerc/sd-webui-flash-attention2-rdna3-rocm, it seems possible to implement a usable WMMA-based backward fused-attention kernel. Unfortunately, I can't use these directly myself, as both are tailored for image generation (Stable Diffusion), whereas I am interested in FA2 support for LLM training instead.
Are there any plans for adding fused attention backward pass support for RDNA3+ GPUs to CK in the foreseeable future? This seems especially pressing with the W7900 Dual Slot, an RDNA3 GPU, being recommended for AI workstation usage, where the ability to make effective use of this GPU's 48GB of VRAM during training feels much more like a core use case.
Operating System
Ubuntu 22.04 LTS
CPU
AMD Ryzen 9 7950X (non-3D)
GPU
AMD Radeon RX 7900 XTX, AMD Radeon Pro W7900, AMD Radeon Pro W7800, AMD Radeon RX 7900 XT
Other
No response
ROCm Version
ROCm 6.0.0
ROCm Component
Composable Kernel
Steps to Reproduce
No response
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
No response
I was trying to understand ck_tile in preparation for writing FA kernels for the 7900 series, but I am confused by the tile window part. In old CK we could use threadgroup and thread slice transfer, but now we have to use tile_window. The parameters of tile_window are hard to understand, and there are few comments. :(
Hi @Googulator. Internal ticket has been created to assist with your issue. Thanks!
Hi @Googulator, CK currently only supports FA for MI-series cards; for example, https://github.com/ROCm/flash-attention has forward and backward attention with a CK backend for MI200 and MI300, but not on RDNA3. We're aware that better FA support is needed for Radeon cards, especially since it would further leverage the VRAM advantage of these cards, as you mention. Improvements are in the pipeline, although there isn't a specific timeline I can provide for this.
@demonsan Please open a separate issue for this if you haven't already so we can try to provide guidance and figure out where documentation can be added or improved.
Guys, I'm reopening this ticket because the developer asked a perfectly reasonable question about how to use CK. Redirecting to an internal ticket and then closing in favor of a ticket for support of specific kernels is not the right way to handle this.
I will discuss how we get some of these things done internally with you all, but in the meantime, can we answer the technical question here and at least provide some breadcrumbs on how to program ck for these cases?
@carlos-amd can we answer this question?
Hi, I have figured out the meaning of most of the ck_tile components and written some kernels to check them. Currently I am writing some docs on them; the first version will be in Chinese. Is that OK? I can put some Zhihu links here once the documents are ready.
That will be fantastic. I'll personally need to use a translator, but having the knowledge available is the most important thing. Thank you.
A ck_tile implementation for gfx1100 would be so awesome. If I see it correctly, it would enable Flash Attention and AITER, which in turn would allow running things that rely on them, like SGLang. Very much looking forward to your work!
Hi, the first post is about the CK repo's code structure. An article about tile window and the GEMM kernel will be posted soon: AMD Composable Kernel (CK Tile) Introduction (article by DEMON on Zhihu) https://zhuanlan.zhihu.com/p/1907162735860496069
FWIW, I see that as of today, the flash-attention repo does support Radeon cards via the Triton backend.
Hi, I bought a 9070 XT card recently. An RDNA4 CK Tile version is planned. I will run some tests and support basic components like WMMA etc. first. Detailed info will be posted in several blogs. :)
Does the Triton backend support the forward pass only, or also the backward pass?
It's mentioned in the link https://github.com/ROCm/flash-attention?tab=readme-ov-file#triton-backend
These features are supported in Fwd and Bwd:
Fwd and Bwd with causal masking
Variable sequence lengths
Arbitrary Q and KV sequence lengths
Arbitrary head sizes
Multi and grouped query attention
Dropout
Rotary embeddings
ALiBi
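For reference, here is a rough sketch of exercising both passes through the flash-attn Python API (tensor shapes are arbitrary; the environment variable used to select the Triton backend is my reading of that README and should be double-checked there):

```python
# Hedged sketch: run a forward and backward pass through flash_attn_func.
# The env var below is assumed from the ROCm/flash-attention README (Triton backend).
import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"  # assumed backend selector, verify in the README

import torch
from flash_attn import flash_attn_func

# q, k, v: (batch, seqlen, nheads, headdim), half precision, on the GPU
q, k, v = (torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3))

out = flash_attn_func(q, k, v, causal=True)  # forward pass
out.sum().backward()                         # backward pass, the part CK lacks on RDNA3
print(q.grad.shape)
```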
@Googulator have you tried PyTorch's scaled_dot_product_attention instead? It uses aotriton to implement flash attention (both fwd and bwd) with WMMA-based kernels via Triton on RDNA3+, and VRAM usage should be very similar to xformers (at least in my experience; YMMV, but VRAM usage should be pretty good).
The latest version of PyTorch comes with aotriton out of the box for Linux, and if you want native Windows support with ROCm+PyTorch+aotriton, you can try my wheels built using TheRock here: https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x
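For anyone who wants to try this route, a minimal sketch (arbitrary sizes, recent PyTorch assumed) of pinning SDPA to the flash-attention backend and running the backward pass:

```python
# Hedged sketch: route attention through PyTorch's SDPA flash-attention path
# (backed by aotriton on ROCm) and run the backward pass. Shapes are arbitrary examples.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# q, k, v: (batch, nheads, seqlen, headdim)
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
           for _ in range(3))

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):       # restrict SDPA to the flash-attention backend
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

out.sum().backward()   # the backward pass also goes through the flash-attention kernels
print(v.grad.dtype)
```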
Updated some CK Tile memory access component details here: Introduction to CK Tile Memory Access Components (article by DEMON on Zhihu) https://zhuanlan.zhihu.com/p/1922322825718510877
Hi all! I wanted to inform you that we are working on WMMA support for FMHA: #2528
There are still some things to do (see the list in the PR), but most of the functionality seems to work well on gfx12 (though I don't know when we will start working on gfx11 support).