
[Feature Request] Flexible MLA Kernel Supporting More `d`, `dv`, and `n_heads` Values

Open JT-Ushio opened this issue 9 months ago • 5 comments

Background

Hi TileLang team!
Thank you for the fantastic repo — the performance of the MLA kernel is truly impressive.

We recently proposed a framework called MHA2MLA, which enables converting MHA- or GQA-based models (such as Llama, Mistral, Qwen, etc.) to an MLA architecture at very low cost (e.g., continued fine-tuning on only 0.3% of the original pretraining tokens). Due to constraints inherited from the original MHA/GQA dimensions, the converted MLA cannot be perfectly aligned with the DeepSeek MLA (for instance, Llama2-7B-d_kv_16 has n_heads=32, dv=512, and d=512+16), so it cannot directly use high-efficiency kernels like FlashMLA that are strictly dimension-bound.

We noticed that TileLang_MLA offers inference efficiency comparable to FlashMLA while allowing flexible dimension adjustments. In our initial experiments, it runs stably with configurations such as n_heads=64, dv=512, and d=512+32. To match these supported shapes, we had to zero-pad models like Llama2-7B-d_kv_16 (n_heads=32 → 64, d=512+16 → 512+32); even so, our TileLang+MLA version achieved faster inference than FlashAttention2+MHA on an H100-80G (see the table below).

LLaMA2-7B (d_kv_16 @ MLA)   Bsz   Seqlen   Attention latency [ms]
-------------------------   ---   ------   ----------------------
MHA w. FlashAttention2        8     2K       5.95
MHA w. FlashAttention2        8     4K      10.66
MHA w. FlashAttention2       64     2K      43.74
MLA w. TileLang Kernel        8     2K       3.52 (-41%)
MLA w. TileLang Kernel        8     4K       6.08 (-43%)
MLA w. TileLang Kernel       64     2K       4.10 (-91%)
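
For concreteness, the padding we apply is roughly as follows (a minimal sketch covering only the RoPE tensors, assuming plain zero-fill with F.pad; the actual MHA2MLA preprocessing may differ in layout and details):

import torch
import torch.nn.functional as F

# Illustrative decode-step shapes for Llama2-7B-d_kv_16 after MHA2MLA
bsz, seqlen, n_heads, d_pe = 8, 2048, 32, 16

q_pe = torch.randn(bsz, 1, n_heads, d_pe)       # RoPE queries (q_len = 1)
k_pe = torch.randn(bsz, n_heads, seqlen, d_pe)  # per-head RoPE keys

# Pad the RoPE dimension 16 -> 32; the extra zero features leave every dot product unchanged.
q_pe = F.pad(q_pe, (0, 32 - d_pe))
k_pe = F.pad(k_pe, (0, 32 - d_pe))

# Pad the head dimension 32 -> 64 with all-zero heads, whose outputs are discarded afterwards.
q_pe = F.pad(q_pe, (0, 0, 0, 64 - n_heads))        # heads are dim 2 of q_pe
k_pe = F.pad(k_pe, (0, 0, 0, 0, 0, 64 - n_heads))  # heads are dim 1 of k_pe

print(q_pe.shape, k_pe.shape)  # torch.Size([8, 1, 64, 32]) torch.Size([8, 64, 2048, 32])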

However, zero-padding may waste memory and compute. Since our kernel programming experience is limited, we would greatly appreciate your help in making the MLA kernel more flexible.

Feature Requests

Feature 1: Support for More Flexible Values of n_heads, dv, and d

  • n_heads: e.g., 16, 24, 32, etc.
  • dv: e.g., 64, 128, 256, 512, etc.
  • d: e.g., dv+8, dv+16, dv+32, dv+64, etc.

Feature 2: Support for Non-shared k_r Across Attention Heads

  • In the original MLA, k_r (the portion of the key vector that carries positional encoding) has shape (bsz*seqlen, 1, d_pe).
  • To align with the pretrained MHA parameters in MHA2MLA, k_r should instead have shape (bsz*seqlen, n_heads, d_pe), as illustrated in the sketch below.
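
As a quick shape-level illustration (sizes made up for the example), the only change is that the head dimension of k_r can no longer be collapsed to 1:

import torch

bsz, seqlen, n_heads, d_pe = 8, 2048, 32, 16

# Original MLA: the RoPE key is computed once per position and shared by every head
k_r_shared = torch.randn(bsz * seqlen, 1, d_pe)

# MHA2MLA: each head keeps the RoPE key slice inherited from its pretrained MHA head
k_r_per_head = torch.randn(bsz * seqlen, n_heads, d_pe)

# The shared layout is the special case where all heads see identical keys
assert k_r_shared.expand(-1, n_heads, -1).shape == k_r_per_head.shape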

Would the TileLang team be willing to help us develop a more flexible MLA kernel? If not, could you offer some guidance on how to approach this? We would be happy to acknowledge your contributions in our Paper, GitHub repo, and HuggingFace Models.

Thanks again for your excellent work, and we look forward to your thoughts!

Best regards

JT-Ushio avatar Apr 07 '25 05:04 JT-Ushio

Hi Jitao, we’re glad to hear that Tilelang has been helpful in your high-performance kernel development. Thank you for your request — your requirements seem quite similar to the GQA implementation we previously developed: GQA Example. Our team will review your request. cc @xs-keju

chengyupku avatar Apr 07 '25 06:04 chengyupku

Hi @chengyupku, thanks for the quick response! I'll take a close look at the GQA implementation, and I'll be sure to share any progress or challenges as they come up. Looking forward to a successful integration of MHA2MLA and TileLang_MLA!

JT-Ushio avatar Apr 07 '25 06:04 JT-Ushio

Hi @chengyupku, I have taken a stab at this; a reference program for @JT-Ushio's request is attached below.

The major difference between MHA2MLA and MLA is that the RoPE part falls back to a batched GEMV, since each RoPE query head now has to be multiplied with its own (non-shared) RoPE keys.

import torch
import torch.nn.functional as F
import argparse
from tilelang.autotuner import *


def mha2mla_ref(q, q_pe, kv, k_pe):
    """
    Inputs:
    - q (Tensor): [batch, q_len, q_head_num, dim]
    - q_pe (Tensor): [batch, q_len, q_head_num, pe_dim]
    - kv (Tensor): [batch, kv_head_num, kv_len, dim]
    - k_pe (Tensor): [batch, q_head_num, kv_len, pe_dim]
    Outputs:
    - output (Tensor): [batch, q_head_num, dim]
    """
    batch = q.shape[0]
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    scale = (dim + pe_dim)**0.5
    seqlen_q = q.shape[1]
    kv_head_num = kv.shape[1]
    q_head_num = q.shape[2]
    assert seqlen_q == 1, "assuming the decoding stage. seqlen_q should be 1"
    assert kv_head_num == 1, "assuming kv_head_num is 1"
    num_head_groups = q_head_num // kv_head_num

    # part 1: scores  
    ## A. NoPE parts (GEMM)
    q = q.view(batch, kv_head_num, num_head_groups*1, dim) # [batch, kv_head_num, num_head_groups*1, dim]
    kv_t = kv.transpose(-1, -2) # [batch, kv_head_num, dim, kv_len]
    scores_no_rope = torch.matmul(q, kv_t) # [batch, kv_head_num, num_head_groups*1, kv_len] Ex: (bsz, 1, 32, 8192)
    
    ## B. RoPE parts (GEMV)
    q_pe = q_pe.view(batch, num_head_groups*1, kv_head_num, pe_dim) # [batch, q_head_num, q_len, pe_dim]
    kv_pe_t = k_pe.transpose(-1, -2) # [batch, q_head_num, pe_dim, kv_len]
    scores_rope = torch.matmul(q_pe, kv_pe_t) # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)
    
    ## C. Combine scores
    scores = scores_no_rope.transpose(1, 2) + scores_rope # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)
    scores = scores / scale # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)

    # part 2: Apply softmax
    attention = F.softmax(scores, dim=-1)  # [batch, q_head_num, q_len, kv_len] Ex: (bsz, 32, 1, 8192)
    # part 3: Compute output
    out = torch.matmul(attention, kv) # [batch, q_head_num, q_len, dim] Ex: (bsz, 32, 1, 512)
    out = out.squeeze(2) # [batch, q_head_num, dim] Ex: (bsz, 32, 512)
    return out


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=1, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='q heads number')
    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
    parser.add_argument('--dim', type=int, default=512, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
    
    # Create random tensors for inputs
    q = torch.randn(batch, 1, heads, dim, device="cuda", dtype=torch.float16)  # [batch, q_len, q_head_num, dim]
    q_pe = torch.randn(batch, 1, heads, pe_dim, device="cuda", dtype=torch.float16)  # [batch, q_len, q_head_num, pe_dim]
    kv = torch.randn(batch, kv_heads, kv_ctx, dim, device="cuda", dtype=torch.float16)  # [batch, kv_head_num, kv_len, dim]
    k_pe = torch.randn(batch, heads, kv_ctx, pe_dim, device="cuda", dtype=torch.float16)  # [batch, q_head_num, kv_len, pe_dim]
    
    # Run the reference program
    ref_output = mha2mla_ref(q, q_pe, kv, k_pe)
    print(ref_output.shape)

shadowpa0327 avatar Apr 08 '25 20:04 shadowpa0327

Thanks for the explanation, Chi-Chih! This will help guide our implementation of the Tilelang kernel.

chengyupku avatar Apr 09 '25 04:04 chengyupku

Hi @chengyupku, thank you!

Yesterday I started modifying example_mla_decode.py with a minimal change in mind: replacing the GEMM between Q_pe_shared and K_pe_shared with a batched GEMV.

Here's the relevant snippet:

for k in T.Pipelined(loop_range, num_stages=2):
    kv_start = (seqlen_kv // num_split) * bz + k * block_N
    kv_end = (seqlen_kv // num_split) * bz + (k + 1) * block_N
    T.copy(KV[bx, kv_start:kv_end, cur_kv_head, :], KV_shared)
    T.copy(K_pe[bx, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
    T.clear(acc_s)
    T.gemm(
        Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
    ###### Change this line!!!! #####
    T.gemm(
        Q_pe_shared,
        K_pe_shared,
        acc_s,
        transpose_B=True,
        policy=T.GemmWarpPolicy.FullCol)

I'm currently exploring the examples to figure out how to implement this switch to a batched GEMV correctly; at the moment I'm looking into example_gemv.py. If you have any insights or can point me to other relevant examples, I'd really appreciate it!
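
For reference, here is the tensor-level contrast between what the current line computes and what the per-head version needs (plain PyTorch with made-up tile sizes, not TileLang code):

import torch

heads, block_N, pe_dim = 32, 64, 16

Q_pe = torch.randn(heads, pe_dim)                    # one decode-step query row per head
K_pe_shared = torch.randn(block_N, pe_dim)           # MLA: RoPE keys shared by all heads
K_pe_per_head = torch.randn(heads, block_N, pe_dim)  # MHA2MLA: per-head RoPE keys

# Current kernel: a single GEMM over the whole head block
acc_shared = Q_pe @ K_pe_shared.T                    # [heads, block_N]

# MHA2MLA: each head contracts with its own keys, i.e. a batched GEMV over the head dimension
acc_per_head = torch.einsum('hd,hnd->hn', Q_pe, K_pe_per_head)  # [heads, block_N]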

shadowpa0327 avatar Apr 09 '25 05:04 shadowpa0327