[Feature Request] Flexible MLA Kernel Supporting More `d`, `dv`, and `n_heads` Values
Background
Hi TileLang team!
Thank you for the fantastic repo — the performance of the MLA kernel is truly impressive.
We recently proposed a framework called MHA2MLA, which enables converting MHA- or GQA-based models (such as Llama, Mistral, Qwen, etc.) to an MLA architecture at very low cost (e.g., continued fine-tuning on only 0.3% of the original pretraining tokens). Due to constraints from the original MHA/GQA dimensions, the converted MLA cannot be perfectly aligned with DeepSeek's MLA (for instance, Llama2-7B-d_kv_16 has n_heads=32, dv=512, and d=512+16), meaning it cannot directly use high-efficiency kernels like FlashMLA that are strictly dimension-bound.
We noticed that TileLang_MLA offers comparable inference efficiency to FlashMLA while allowing flexible dimension adjustments. In our initial experiments, it runs stably with configurations such as n_heads=64, dv=512, and d=512+32. To match these supported shapes, we had to zero-pad models like Llama2-7B-d_kv_16 (n_heads=32 → 64, d=512+16 → 512+32), and as a result, our TileLang+MLA version achieved faster inference than FlashAttention2+MHA on H100-80G (please see table below).
| LLaMA2-7B (d_kv_16@MLA) | Bsz | Seqlen | Attention's Latency [ms] |
|---|---|---|---|
| MHA w. FlashAttention2 | 8 | 2K | 5.95 |
| | 8 | 4K | 10.66 |
| | 64 | 2K | 43.74 |
| MLA w. TileLang Kernel | 8 | 2K | 3.52 (-41%) |
| | 8 | 4K | 6.08 (-43%) |
| | 64 | 2K | 4.10 (-91%) |
However, zero-padding may waste memory and compute. Since our kernel programming experience is limited, we would greatly appreciate your help in making the MLA kernel more flexible.
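For concreteness, a minimal sketch of the zero-padding workaround described above (tensor names and the exact set of padded tensors are illustrative only, not our actual integration code):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes for Llama2-7B-d_kv_16: dv=512, d_pe=16, n_heads=32.
bsz, seqlen, n_heads, d_pe = 8, 2048, 32, 16
q_pe = torch.randn(bsz, seqlen, n_heads, d_pe, dtype=torch.float16)

# Zero-pad the RoPE dim 16 -> 32 and the head dim 32 -> 64 so the tensor matches
# a shape the current kernel accepts; the padded heads/channels are wasted
# memory and compute, which is what motivates this feature request.
q_pe = F.pad(q_pe, (0, 32 - d_pe))            # [bsz, seqlen, 32, 32]
q_pe = F.pad(q_pe, (0, 0, 0, 64 - n_heads))   # [bsz, seqlen, 64, 32]
```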
Feature Requests
Feature 1: Support for More Flexible Values of n_heads, dv, and d
- `n_heads`: e.g., 16, 24, 32, etc.
- `dv`: e.g., 64, 128, 256, 512, etc.
- `d`: e.g., `dv+8`, `dv+16`, `dv+32`, `dv+64`, etc.
Feature 2: Support for Non-shared k_r Across Attention Heads
- In the original MLA, `k_r` (the portion of the key vector with positional encoding) has the shape `(bsz*seqlen, 1, d_pe)`.
- To align with pretrained MHA parameters in MHA2MLA, `k_r` should be modified to `(bsz*seqlen, n_heads, d_pe)` (see the sketch below).
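A minimal illustration of the two `k_r` layouts (shapes taken from above; the random tensors are placeholders only):

```python
import torch

bsz, seqlen, n_heads, d_pe = 8, 2048, 32, 16

# Original MLA: a single RoPE key shared by every attention head.
k_r_shared = torch.randn(bsz * seqlen, 1, d_pe, dtype=torch.float16)

# MHA2MLA: each head keeps its own RoPE key inherited from the MHA checkpoint,
# so the RoPE score for head h must use k_r[:, h, :] rather than a broadcast.
k_r_per_head = torch.randn(bsz * seqlen, n_heads, d_pe, dtype=torch.float16)
```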
Would the TileLang team be willing to help us develop a more flexible MLA kernel? If not, could you offer some guidance on how to approach this? We would be happy to acknowledge your contributions in our Paper, GitHub repo, and HuggingFace Models.
Thanks again for your excellent work, and we look forward to your thoughts!
Best regards
Hi Jitao, we’re glad to hear that Tilelang has been helpful in your high-performance kernel development. Thank you for your request — your requirements seem quite similar to the GQA implementation we previously developed: GQA Example. Our team will review your request. cc @xs-keju
Hi @chengyupku, thanks for the quick response! I'll take a close look at the GQA implementation, and I'll be sure to share any progress or challenges as they come up. Looking forward to a successful integration of MHA2MLA and TileLang_MLA!
Hi @chengyupku, I have tried to work on this and the reference program of @JT-Ushio's requests is attached below.
The major difference between MHA2MLA and MLA is that the RoPE part will fall back to GEMV, as different RoPE queries will need to multiply with different RoPE keys.
```python
import argparse

import torch
import torch.nn.functional as F


def mha2mla_ref(q, q_pe, kv, k_pe):
    """
    Inputs:
    - q    (Tensor): [batch, q_len, q_head_num, dim]
    - q_pe (Tensor): [batch, q_len, q_head_num, pe_dim]
    - kv   (Tensor): [batch, kv_head_num, kv_len, dim]
    - k_pe (Tensor): [batch, q_head_num, kv_len, pe_dim]
    Outputs:
    - output (Tensor): [batch, q_head_num, dim]
    """
    batch = q.shape[0]
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    scale = (dim + pe_dim) ** 0.5
    seqlen_q = q.shape[1]
    kv_head_num = kv.shape[1]
    q_head_num = q.shape[2]
    assert seqlen_q == 1, "assuming the decoding stage. seqlen_q should be 1"
    assert kv_head_num == 1, "assuming kv_head_num is 1"
    num_head_groups = q_head_num // kv_head_num

    # part 1: scores
    ## A. NoPE part (GEMM): all query heads attend to the single shared latent KV
    q = q.view(batch, kv_head_num, num_head_groups * 1, dim)  # [batch, kv_head_num, num_head_groups*1, dim]
    kv_t = kv.transpose(-1, -2)  # [batch, kv_head_num, dim, kv_len]
    scores_no_rope = torch.matmul(q, kv_t)  # [batch, kv_head_num, num_head_groups*1, kv_len] Ex: (bsz, 1, 32, 8192)

    ## B. RoPE part (GEMV): each query head multiplies its own RoPE keys
    q_pe = q_pe.view(batch, num_head_groups * 1, kv_head_num, pe_dim)  # [batch, q_head_num, q_len, pe_dim]
    kv_pe_t = k_pe.transpose(-1, -2)  # [batch, q_head_num, pe_dim, kv_len]
    scores_rope = torch.matmul(q_pe, kv_pe_t)  # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)

    ## C. Combine scores
    scores = scores_no_rope.transpose(1, 2) + scores_rope  # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)
    scores = scores / scale  # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)

    # part 2: Apply softmax
    attention = F.softmax(scores, dim=-1)  # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)

    # part 3: Compute output
    out = torch.matmul(attention, kv)  # [batch, q_head_num, q_len, dim] Ex: (bsz, 32, 1, 512)
    out = out.squeeze(2)  # [batch, q_head_num, dim] Ex: (bsz, 32, 512)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='q heads number')
    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
    parser.add_argument('--dim', type=int, default=512, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim

    # Create random tensors for inputs
    q = torch.randn(batch, 1, heads, dim, device="cuda", dtype=torch.float16)  # [batch, q_len, q_head_num, dim]
    q_pe = torch.randn(batch, 1, heads, pe_dim, device="cuda", dtype=torch.float16)  # [batch, q_len, q_head_num, pe_dim]
    kv = torch.randn(batch, kv_heads, kv_ctx, dim, device="cuda", dtype=torch.float16)  # [batch, kv_head_num, kv_len, dim]
    k_pe = torch.randn(batch, heads, kv_ctx, pe_dim, device="cuda", dtype=torch.float16)  # [batch, q_head_num, kv_len, pe_dim]

    # Run the reference program
    ref_output = mha2mla_ref(q, q_pe, kv, k_pe)
    print(ref_output.shape)
```
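As a quick sanity check (not part of the original snippet), the reference above can be compared against an equivalent plain-attention formulation that concatenates the NoPE and RoPE parts per head; the lines below assume they are appended to the end of the `__main__` block above:

```python
    # Hypothetical equivalence check: fold the per-head RoPE keys into ordinary
    # multi-head attention by concatenating NoPE and RoPE parts along the channel dim.
    k_full = torch.cat([kv.expand(batch, heads, kv_ctx, dim), k_pe], dim=-1)  # [b, h, kv_len, dim+pe_dim]
    q_full = torch.cat([q, q_pe], dim=-1).transpose(1, 2)                     # [b, h, 1, dim+pe_dim]
    scores = torch.matmul(q_full, k_full.transpose(-1, -2)) / (dim + pe_dim) ** 0.5
    check = torch.matmul(scores.softmax(dim=-1), kv.expand(batch, heads, kv_ctx, dim)).squeeze(2)
    torch.testing.assert_close(ref_output, check, rtol=1e-2, atol=1e-2)
```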
Thanks for the explanation, Chi-Chih! This will help guide our implementation of the Tilelang kernel.
Hi @chengyupku, thank you!
Yesterday I started modifying example_mla_decode.py with a minimal change in mind: replacing the GEMM between Q_pe_shared and K_pe_shared with a batched GEMV.
Here's the relevant snippet:
```python
for k in T.Pipelined(loop_range, num_stages=2):
    kv_start = (seqlen_kv // num_split) * bz + k * block_N
    kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
    T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
    T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
    T.clear(acc_s)
    T.gemm(
        Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
    ###### Change this line!!!! #####
    T.gemm(
        Q_pe_shared,
        K_pe_shared,
        acc_s,
        transpose_B=True,
        policy=T.GemmWarpPolicy.FullCol)
```
I'm currently exploring examples to figure out how to implement this switch to a batched GEMV correctly; right now I'm looking into example_gemv.py. If you have any insights or can point me to relevant examples, I'd really appreciate it!
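For reference (in plain PyTorch with illustrative tile sizes, not the TileLang fragment types), the marked line has to switch from one shared-key GEMM to a per-head dot product against that head's own K_pe slice, i.e. a batched GEMV across heads:

```python
import torch

block_H, block_N, pe_dim = 64, 64, 64  # illustrative tile sizes, not the kernel's actual config

Q_pe_tile = torch.randn(block_H, pe_dim)           # one RoPE query row per head in the tile
K_pe_tile = torch.randn(block_H, block_N, pe_dim)  # per-head RoPE keys for this KV tile

# In shared-k_pe MLA, K_pe is a single [block_N, pe_dim] tile and this is one GEMM.
# In MHA2MLA, head h must only see its own keys, i.e. one GEMV per head:
acc_s_rope = torch.einsum("hd,hnd->hn", Q_pe_tile, K_pe_tile)  # [block_H, block_N]
```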