
WMMA / RDNA3+ kernels for backwards fused attention?

Open Googulator opened this issue 1 year ago • 16 comments

Problem Description

Composable Kernel currently only supports the forward pass of fused attention (FA2) on RDNA3(+) architectures. This greatly increases the VRAM requirements for training LoRAs on LLMs with HuggingFace's Transformers and PEFT libraries: training jobs that succeed on an NVIDIA GeForce RTX 4080 with just 16GB of VRAM fail on a Radeon RX 7900 XT with 20GB.

Based on https://github.com/Repeerc/flash-attention-v2-RDNA3-minimal and https://github.com/Repeerc/sd-webui-flash-attention2-rdna3-rocm, it seems possible to implement a usable WMMA-based backward fused attention kernel. Unfortunately, I can't use these projects directly, as they are both tailored for image generation (Stable Diffusion), whereas I am interested in FA2 support for LLM training.

Are there any plans for adding fused attention backward pass support for RDNA3+ GPUs to CK in the foreseeable future? This seems especially pressing now that the W7900 Dual Slot, an RDNA3 GPU, is being recommended for AI workstation use, where making effective use of its 48GB of VRAM during training feels much more like a core use case.

Operating System

Ubuntu 22.04 LTS

CPU

AMD Ryzen 9 7950X (non-3D)

GPU

AMD Radeon RX 7900 XTX, AMD Radeon Pro W7900, AMD Radeon Pro W7800, AMD Radeon RX 7900 XT

Other

No response

ROCm Version

ROCm 6.0.0

ROCm Component

Composable Kernel

Steps to Reproduce

No response

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

Googulator · Aug 01 '24

I was trying to understand ck_tile and preparing to write FA kernels for the 7900 series, but I am confused by the tile window part. In the old CK we could use threadgroup and thread slice transfer, but now we have to use tile_window. The parameters of tile_window are hard to understand, and there are few comments. :(

demonsan · Oct 14 '24

Hi @Googulator. An internal ticket has been created to assist with your issue. Thanks!

ppanchad-amd · Nov 01 '24

Hi @Googulator, CK currently only supports FA for MI-series cards; for example, https://github.com/ROCm/flash-attention has forward and backward attention with a CK backend for MI200 and MI300, but not on RDNA3. We're aware that better FA support is needed for Radeon cards, especially considering it would further leverage the VRAM advantage of these cards as you mention, and improvements are in the pipeline although there isn't a specific timeline I can provide for this.

@demonsan Please open a separate issue for this if you haven't already so we can try to provide guidance and figure out where documentation can be added or improved.

schung-amd · Nov 01 '24

Guys, I'm reopening this ticket because the developer asked a perfectly reasonable question about how to use ck. Redirecting to an internal ticket and then closing in favor of a ticket for support of specific kernels like that is not the right way.

I will discuss how we get some of these things done internally with you all, but in the meantime, can we answer the technical question here and at least provide some breadcrumbs on how to program ck for these cases?

stellaraccident · May 06 '25

I was trying to understand ck_tile and preparing to write FA kernels for the 7900 series, but I am confused by the tile window part. In the old CK we could use threadgroup and thread slice transfer, but now we have to use tile_window. The parameters of tile_window are hard to understand, and there are few comments. :(

@carlos-amd can we answer this question?

stellaraccident · May 06 '25

Hi, I have figured out the meaning of most of the ck_tile components and have written some kernels to check my understanding. Currently, I am writing some docs on them; the first version will be in Chinese. Is that OK? I can post some Zhihu links here once the documents are ready.
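
Before the articles are ready, here is a rough preview of how the pieces fit together (a hand-written sketch from memory rather than code copied from the repo; exact signatures, the alignment constants, and the way the tile distribution is constructed may differ between CK versions):

```cpp
// Rough sketch only: the three pieces behind ck_tile's tile_window.
// Not copied from the repo; signatures may differ between CK versions.
#include "ck_tile/core.hpp"

template <ck_tile::index_t kM, ck_tile::index_t kK, typename TileDistribution>
__device__ void load_one_tile(const ck_tile::half_t* ptr,
                              ck_tile::index_t rows,
                              ck_tile::index_t cols,
                              ck_tile::index_t row_stride,
                              ck_tile::index_t row_origin,
                              TileDistribution dist)
{
    using namespace ck_tile;

    // 1) tensor view: describes the full tensor in global memory
    //    (lengths, strides, guaranteed vector size of the last dimension).
    auto view = make_naive_tensor_view<address_space_enum::global>(
        ptr,
        make_tuple(rows, cols),
        make_tuple(row_stride, 1),
        number<8>{},   // assumed alignment of the last dim, illustrative
        number<1>{});

    // 2) tile_window: a movable kM x kK window into that view, anchored at an
    //    origin; the tile distribution says how the window's elements are
    //    spread over the block's threads (this is what replaces the old
    //    threadgroup / thread-slice-transfer descriptors).
    auto window = make_tile_window(
        view, make_tuple(number<kM>{}, number<kK>{}), {row_origin, 0}, dist);

    // 3) load_tile reads the windowed region into per-thread registers;
    //    move_tile_window shifts the origin for the next loop iteration.
    auto tile = load_tile(window);
    move_tile_window(window, {0, kK});
    (void)tile;
}
```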

demonsan · May 06 '25

Hi, I have figured out the meaning of most of the ck_tile components and have written some kernels to check my understanding. Currently, I am writing some docs on them; the first version will be in Chinese. Is that OK? I can post some Zhihu links here once the documents are ready.

That will be fantastic. I'll personally need to use a translator, but having the knowledge available is the most important thing. Thank you.

stellaraccident · May 06 '25

A ck_tile implementation for gfx1100 would be so awesome. If I see it correctly, it would enable flash attention and aiter, which in turn would allow running things that rely on them, like SGLang. Very much looking forward to your work!

DrZoidberg09 · May 11 '25

Hi, I have figured out the meaning of most of the ck_tile components and have written some kernels to check my understanding. Currently, I am writing some docs on them; the first version will be in Chinese. Is that OK? I can post some Zhihu links here once the documents are ready.

That will be fantastic. I'll personally need to use a translator, but having the knowledge available is the most important thing. Thank you.

Hi, the first post is about the CK repo's code structure. An article about tile_window and the GEMM kernel will be posted soon: AMD Composable Kernel (CK Tile) Introduction, an article by DEMON on Zhihu: https://zhuanlan.zhihu.com/p/1907162735860496069

demonsan · May 18 '25

Hi @Googulator, CK currently only supports FA for MI-series cards; for example, https://github.com/ROCm/flash-attention has forward and backward attention with a CK backend for MI200 and MI300, but not on RDNA3. We're aware that better FA support is needed for Radeon cards, especially considering it would further leverage the VRAM advantage of these cards as you mention, and improvements are in the pipeline although there isn't a specific timeline I can provide for this.

FWIW, I see that as of today, the flash-attention repo does support Radeon cards via the Triton backend.

jammm · May 26 '25

Hi @Googulator, CK currently only supports FA for MI-series cards; for example, https://github.com/ROCm/flash-attention has forward and backward attention with a CK backend for MI200 and MI300, but not on RDNA3. We're aware that better FA support is needed for Radeon cards, especially considering it would further leverage the VRAM advantage of these cards as you mention, and improvements are in the pipeline although there isn't a specific timeline I can provide for this.

FWIW, I see that as of today, the flash-attention repo does support Radeon cards via the Triton backend.

Hi, I bought a 9070 XT card recently. An RDNA4 CK Tile version is planned. I will run some tests and support basic components like WMMA etc. first. Detailed info will be posted in several blogs. :)

demonsan · May 28 '25

Hi @Googulator, CK currently only supports FA for MI-series cards; for example, https://github.com/ROCm/flash-attention has forward and backward attention with a CK backend for MI200 and MI300, but not on RDNA3. We're aware that better FA support is needed for Radeon cards, especially considering it would further leverage the VRAM advantage of these cards as you mention, and improvements are in the pipeline although there isn't a specific timeline I can provide for this.

FWIW, I see that as of today, the flash-attention repo does support Radeon cards via the Triton backend.

Forward pass only, or also backward?

Googulator · May 30 '25

Forward pass only, or also backward?

It's mentioned in the link https://github.com/ROCm/flash-attention?tab=readme-ov-file#triton-backend; a quick sketch exercising both passes follows the feature list below.

These features are supported in Fwd and Bwd:

Fwd and Bwd with causal masking
Variable sequence lengths
Arbitrary Q and KV sequence lengths
Arbitrary head sizes
Multi and grouped query attention
Dropout
Rotary embeddings
ALiBi
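
For example, a quick smoke test of both passes looks roughly like this (a sketch, not something I've run here; FLASH_ATTENTION_TRITON_AMD_ENABLE is, as far as I recall, the switch that README uses to select the Triton backend, and the shapes below are arbitrary):

```python
# Rough sketch: one forward and one backward pass through the Triton backend.
# The env var is assumed to be the backend switch from the README and must be
# set before flash_attn is imported.
import os
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), fp16 on the GPU
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)

out = flash_attn_func(q, k, v, causal=True)  # Fwd with causal masking
out.sum().backward()                         # Bwd; fills q.grad / k.grad / v.grad
```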

jammm · May 30 '25

@Googulator have you tried PyTorch's scaled_dot_product_attention instead? It uses aotriton to implement flash attention (both fwd and bwd) with WMMA-based kernels via Triton on RDNA3+, and VRAM usage should be very similar to xformers (at least in my experience; YMMV, but VRAM usage should be pretty good).

The latest version of PyTorch comes with aotriton out of the box for Linux, and if you want native Windows support with ROCm+PyTorch+aotriton, you can try my wheels built using TheRock here: https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x
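
Something like the following should be enough to confirm both passes go through it (a sketch; sdpa_kernel and SDPBackend live in torch.nn.attention in recent PyTorch releases, and the shapes are arbitrary):

```python
# Rough sketch: forward + backward through PyTorch's built-in SDPA, which on
# ROCm dispatches to the aotriton flash kernels when they are available.
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend  # recent PyTorch releases

# (batch, nheads, seqlen, headdim), fp16 on the GPU
q = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.float16, requires_grad=True)

# Pinning the flash backend makes this fail loudly if flash attention can't be used.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # forward
    out.sum().backward()                                           # backward
```

With HF Transformers you shouldn't need to call it directly; passing attn_implementation="sdpa" to from_pretrained should route attention through the same path.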

jammm · May 30 '25

Updated with some details on the ck_tile memory access components here: CK Tile Memory Access Components Introduction, an article by DEMON on Zhihu: https://zhuanlan.zhihu.com/p/1922322825718510877

demonsan · Jul 11 '25

Hi all! I wanted to inform you that we are working on WMMA support for FMHA: #2528

There are still some things to do (see the list in the PR), but most of the functionality seems to work well on gfx12 (though I don't know when we will start working on gfx11 support).

ex-rzr · Jul 18 '25