[BUG] fp8_lighting_indexer.py very slow with more than 8 heads
Required prerequisites
- [x] I have read the documentation https://tilelang.com.
- [x] I have searched the Issue Tracker and confirmed that this hasn't already been reported. (comment there if it has.)
What version of TileLang are you using?
0.1.6.post1+cu128.gitd9a0f131
System information
/opt/conda/lib/python3.10/runpy.py:126: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
warn(RuntimeWarning(msg))
Collecting environment information...
PyTorch version: 2.9.1+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.1 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 4.1.2
Libc version: glibc-2.39
Python version: 3.10.19 (main, Oct 21 2025, 16:43:05) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.8.0-87-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.8.61
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Nvidia driver version: 570.124.06
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.7.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.7.0
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9355 32-Core Processor
CPU family: 26
Model: 2
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 101%
CPU max MHz: 3550.0000
CPU min MHz: 1500.0000
BogoMIPS: 7099.85
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx_vnni avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc amd_ibpb_ret arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect movdiri movdir64b overflow_recov succor smca fsrm avx512_vp2intersect flush_l1d debug_swap
Virtualization: AMD-V
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 64 MiB (64 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 8
NUMA node0 CPU(s): 0-7,64-71
NUMA node1 CPU(s): 8-15,72-79
NUMA node2 CPU(s): 16-23,80-87
NUMA node3 CPU(s): 24-31,88-95
NUMA node4 CPU(s): 32-39,96-103
NUMA node5 CPU(s): 40-47,104-111
NUMA node6 CPU(s): 48-55,112-119
NUMA node7 CPU(s): 56-63,120-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Not affected
Versions of relevant libraries:
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.8.4.1
[pip3] nvidia-cuda-cupti-cu12==12.8.90
[pip3] nvidia-cuda-nvrtc-cu12==12.8.93
[pip3] nvidia-cuda-runtime-cu12==12.8.90
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.3.83
[pip3] nvidia-curand-cu12==10.3.9.90
[pip3] nvidia-cusolver-cu12==11.7.3.90
[pip3] nvidia-cusparse-cu12==12.5.8.93
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.8.93
[pip3] nvidia-nvtx-cu12==12.8.90
[pip3] torch==2.9.1
[pip3] triton==3.5.1
[conda] numpy 2.2.6 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.9.1 pypi_0 pypi
[conda] triton 3.5.1 pypi_0 pypi
Problem description
I have a benchmark script I wrote for fp8_lighting_indexer.py:
bench_indexer_tilelang.py
#!/usr/bin/env python3
import argparse
import torch
import os
import sys
from typing import Optional

# Optional TVM runtime import to dump CUDA/PTX sources
import tilelang
from tilelang import tvm
from tvm import runtime as tvm_rt

# Prefer local examples path resolution if running from repo root
try:
    from examples.deepseek_v32.utils import per_custom_dims_cast_to_fp8 as _to_fp8

    def to_fp8(x):
        # Cast along last dim to FP8 E4M3 to match kernel expectations
        # Handle both (x, dims, use_ue8m0) and (x, dims) signatures and return the scaled tensor only.
        try:
            x_scaled, _ = _to_fp8(x, dims=(-1,), use_ue8m0=False)
            return x_scaled
        except TypeError:
            out = _to_fp8(x, dims=(-1,))
            return out[0] if isinstance(out, tuple) else out
except Exception:

    def to_fp8(x):
        # Fallback: use PyTorch FP8 E4M3 if TileLang utils are unavailable
        if not hasattr(torch, "float8_e4m3fn"):
            raise RuntimeError("torch.float8_e4m3fn not available; install a CUDA-enabled PyTorch.")
        return x.to(torch.float8_e4m3fn)


# Try to ensure TVM runtime is importable (vendored TVM + build lib dirs)
def _ensure_tvm_runtime() -> bool:
    global tvm_rt
    if tvm_rt is not None:
        return True
    # First, try importing directly
    try:
        from tvm import runtime as _rt  # type: ignore
        tvm_rt = _rt  # type: ignore
        return True
    except Exception:
        pass
    # Add vendored TVM python path
    try:
        here = os.path.abspath(os.path.dirname(__file__))
        root = here
        vendored = os.path.join(root, "3rdparty", "tvm", "python")
        if os.path.isdir(vendored) and vendored not in sys.path:
            sys.path.insert(0, vendored)
        # Add build library dirs for TVM
        libdirs = [os.path.join(root, "build", "tvm"), os.path.join(root, "build", "lib")]
        libdirs = [p for p in libdirs if os.path.isdir(p)]
        if libdirs:
            sep = ":" if os.name != "nt" else ";"
            add = sep.join(libdirs)
            os.environ["TVM_LIBRARY_PATH"] = add + (sep + os.environ["TVM_LIBRARY_PATH"] if "TVM_LIBRARY_PATH" in os.environ else "")
            os.environ["LD_LIBRARY_PATH"] = add + (sep + os.environ["LD_LIBRARY_PATH"] if "LD_LIBRARY_PATH" in os.environ else "")
        from tvm import runtime as _rt  # type: ignore
        tvm_rt = _rt  # type: ignore
        return True
    except Exception:
        return False


# Utilities to extract CUDA/PTX sources from compiled TileLang kernels
def _get_rt_mod_from_kernel(kernel) -> Optional["tvm_rt.Module"]:
    if tvm_rt is None:
        return None
    # Direct attachments
    for attr in ("rt_mod", "module", "mod"):
        m = getattr(kernel, attr, None)
        if m is not None and isinstance(m, tvm_rt.Module):
            return m
    # Nested wrappers commonly used by TileLang frontends
    for inner_name in ("impl", "fn", "kernel", "launcher"):
        inner = getattr(kernel, inner_name, None)
        if inner is None:
            continue
        for attr in ("rt_mod", "module", "mod"):
            m = getattr(inner, attr, None)
            if m is not None and isinstance(m, tvm_rt.Module):
                return m
    return None


# Prefer printing CUDA source from TileLang artifact if available
def _print_kernel_cuda_from_artifact(kernel, kernel_name: str) -> bool:
    try:
        art = getattr(kernel, "artifact", None)
        if art is not None:
            src = getattr(art, "kernel_source", None)
            if src:
                print(f"===== BEGIN {kernel_name} CUDA =====")
                print(src)
                print(f"===== END {kernel_name} CUDA =====")
                return True
    except Exception as e:
        print(f"[KERNEL_SRC] Artifact kernel_source not available for {kernel_name}: {e}")
    return False


def _print_kernel_sources(kernel, kernel_name: str):
    if tvm_rt is None:
        if not _ensure_tvm_runtime():
            print(f"[KERNEL_SRC] TVM runtime not available; cannot dump sources for {kernel_name}")
            return
    try:
        rt_mod = _get_rt_mod_from_kernel(kernel)
        if rt_mod is None:
            print(f"[KERNEL_SRC] No TVM runtime module found on kernel {kernel_name}")
            return
        # Many CUDA builds store device code in imported_modules[0].
        try:
            imported = list(rt_mod.imported_modules)
        except Exception:
            imported = []
        if not imported:
            imported = [rt_mod]
        device_mod = imported[0]
        any_src = False
        for fmt in ("cuda", "ptx"):
            try:
                src = device_mod.get_source(fmt)
            except Exception:
                src = None
            if src:
                any_src = True
                header = f"===== BEGIN {kernel_name} {fmt.upper()} ====="
                footer = f"===== END {kernel_name} {fmt.upper()} ====="
                print(header)
                print(src)
                print(footer)
        if not any_src:
            print(f"[KERNEL_SRC] Module did not expose CUDA/PTX sources for {kernel_name}")
    except Exception as e:
        print(f"[KERNEL_SRC] Failed to get sources for {kernel_name}: {e}")


# TileLang example kernels for lightning indexer
from examples.deepseek_v32.fp8_lighting_indexer import (
    mqa_attn_return_logits,
    mqa_attn_return_logits_interface,
)


def bench_tl_indexer_wrapper(seq_len: int,
                             seq_len_kv: int,
                             heads: int = 4,
                             index_dim: int = 64,
                             iters: int = 50,
                             warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")
    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Convert to FP8 E4M3 to match kernel signature
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)
    # Precompute kv_scales similar to reference: sqrt(mean(k^2)) along dim=-1
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    # Warmup
    for _ in range(warmup):
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
    torch.cuda.synchronize()
    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        _ = mqa_attn_return_logits_interface(q_fp8, kv_fp8, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke)
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] WRAPPER S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")
    # Dump kernel source (prefer artifact.kernel_source)
    try:
        kwrap = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
        # Force a tiny compile-run to populate artifact
        S_sm, SKV_sm = 32, 32
        q_sm = torch.randn(S_sm, heads, index_dim, device=device, dtype=torch.float32)
        kv_sm = torch.randn(SKV_sm, index_dim, device=device, dtype=torch.float32)
        q_sm_fp8 = to_fp8(q_sm)
        kv_sm_fp8 = to_fp8(kv_sm)
        kv_scales_sm = kv_sm.pow(2).mean(dim=-1).sqrt()
        weights_sm = torch.randn(S_sm, heads, device=device, dtype=torch.float32)
        cu_seqlen_ks_sm = torch.zeros(S_sm, dtype=torch.int32, device=device)
        cu_seqlen_ke_sm = torch.full((S_sm,), SKV_sm, dtype=torch.int32, device=device)
        logits_sm = torch.empty(S_sm, SKV_sm, device=device, dtype=torch.float32)
        kwrap(q_sm_fp8.view(S_sm * heads, index_dim), kv_sm_fp8, kv_scales_sm,
              logits_sm, weights_sm, cu_seqlen_ks_sm, cu_seqlen_ke_sm)
        torch.cuda.synchronize()
        if not _print_kernel_cuda_from_artifact(kwrap, "mqa_attn_return_logits_kernel"):
            _print_kernel_sources(kwrap, "mqa_attn_return_logits_kernel")
    except Exception as e:
        print(f"[KERNEL_SRC] Wrapper: unable to dump kernel sources: {e}")


def bench_tl_indexer_impl(seq_len: int,
                          seq_len_kv: int,
                          heads: int = 4,
                          index_dim: int = 64,
                          iters: int = 50,
                          warmup: int = 5):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    device = torch.device("cuda")
    # Compile kernel once
    kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
    # Inputs
    q = torch.randn(seq_len, heads, index_dim, device=device, dtype=torch.float32)
    kv = torch.randn(seq_len_kv, index_dim, device=device, dtype=torch.float32)
    # Convert to FP8 E4M3 to match kernel signature
    q_fp8 = to_fp8(q)
    kv_fp8 = to_fp8(kv)
    kv_scales = kv.pow(2).mean(dim=-1).sqrt()
    weights = torch.randn(seq_len, heads, device=device, dtype=torch.float32)
    cu_seqlen_ks = torch.zeros(seq_len, dtype=torch.int32, device=device)
    cu_seqlen_ke = torch.full((seq_len,), seq_len_kv, dtype=torch.int32, device=device)
    logits = torch.empty(seq_len, seq_len_kv, device=device, dtype=torch.float32)
    # Warmup
    for _ in range(warmup):
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
    torch.cuda.synchronize()
    # Timed
    times = []
    for _ in range(iters):
        t0 = torch.cuda.Event(enable_timing=True)
        t1 = torch.cuda.Event(enable_timing=True)
        t0.record()
        kernel(
            q_fp8.view(seq_len * heads, index_dim),
            kv_fp8,
            kv_scales,
            logits,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
        )
        t1.record()
        t1.synchronize()
        times.append(t0.elapsed_time(t1))  # ms
    avg_ms = sum(times) / len(times) if times else float("nan")
    print(f"[TILELANG_INDEXER] IMPL S={seq_len} SKV={seq_len_kv} H={heads} D={index_dim} avg_ms={avg_ms:.3f} over {iters}")
    # Dump kernel source for impl (prefer artifact.kernel_source)
    try:
        # kernel is already compiled by this point, but if artifact is not populated, force a tiny run
        if not _print_kernel_cuda_from_artifact(kernel, "mqa_attn_return_logits_kernel"):
            S_sm, SKV_sm = 32, 32
            q_sm = torch.randn(S_sm, heads, index_dim, device=device, dtype=torch.float32)
            kv_sm = torch.randn(SKV_sm, index_dim, device=device, dtype=torch.float32)
            q_sm_fp8 = to_fp8(q_sm)
            kv_sm_fp8 = to_fp8(kv_sm)
            kv_scales_sm = kv_sm.pow(2).mean(dim=-1).sqrt()
            weights_sm = torch.randn(S_sm, heads, device=device, dtype=torch.float32)
            cu_seqlen_ks_sm = torch.zeros(S_sm, dtype=torch.int32, device=device)
            cu_seqlen_ke_sm = torch.full((S_sm,), SKV_sm, dtype=torch.int32, device=device)
            logits_sm = torch.empty(S_sm, SKV_sm, device=device, dtype=torch.float32)
            kernel(q_sm_fp8.view(S_sm * heads, index_dim), kv_sm_fp8, kv_scales_sm,
                   logits_sm, weights_sm, cu_seqlen_ks_sm, cu_seqlen_ke_sm)
            torch.cuda.synchronize()
            if not _print_kernel_cuda_from_artifact(kernel, "mqa_attn_return_logits_kernel"):
                _print_kernel_sources(kernel, "mqa_attn_return_logits_kernel")
    except Exception as e:
        print(f"[KERNEL_SRC] Impl: unable to dump kernel sources: {e}")


def parse_int_list(s: str):
    vals = []
    for part in s.split(','):
        part = part.strip()
        if not part:
            continue
        vals.append(int(part))
    return vals


def main():
    parser = argparse.ArgumentParser(description="Benchmark TileLang lightning indexer (DeepSeek V3.2)")
    parser.add_argument("--seq-lens", type=parse_int_list, default="4096,16384,163840",
                        help="Comma-separated sequence lengths S (default: 4096,16384,163840)")
    parser.add_argument("--kv-lens", type=parse_int_list, default=None,
                        help="Comma-separated KV lengths SKV; if omitted, uses seq-lens")
    parser.add_argument("--heads", type=int, default=4, help="Indexer heads H (default: 4)")
    parser.add_argument("--dim", type=int, default=64, help="Indexer dimension D (default: 64)")
    parser.add_argument("--iters", type=int, default=50, help="Timed iterations (default: 50)")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations (default: 5)")
    parser.add_argument("--mode", choices=["both", "wrapper", "impl"], default="both",
                        help="Which path to benchmark: wrapper (interface), impl (kernel), or both (default)")
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for this benchmark.")
    dev = torch.cuda.get_device_name(0)
    print(f"CUDA device: {dev}")
    seq_lens = args.seq_lens if isinstance(args.seq_lens, list) else parse_int_list(args.seq_lens)
    kv_lens = None
    if args.kv_lens is None:
        kv_lens = seq_lens
    else:
        kv_lens = args.kv_lens if isinstance(args.kv_lens, list) else parse_int_list(args.kv_lens)
    if len(kv_lens) != len(seq_lens):
        raise ValueError("--kv-lens must have the same number of elements as --seq-lens")
    for S, SKV in zip(seq_lens, kv_lens):
        if args.mode in ("both", "wrapper"):
            bench_tl_indexer_wrapper(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)
        if args.mode in ("both", "impl"):
            bench_tl_indexer_impl(S, SKV, heads=args.heads, index_dim=args.dim, iters=args.iters, warmup=args.warmup)


if __name__ == "__main__":
    main()
If I run this with 4 heads, like this:
python bench_indexer_tilelang.py --heads 4 --seq-lens '4096,16384,150840'
I get results quickly:
2025-11-14 00:40:19 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /root/TileLang/build
CUDA device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
[TILELANG_INDEXER] WRAPPER S=4096 SKV=4096 H=4 D=64 avg_ms=0.062 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=4096 SKV=4096 H=4 D=64 avg_ms=0.050 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=16384 SKV=16384 H=4 D=64 avg_ms=0.776 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=16384 SKV=16384 H=4 D=64 avg_ms=0.701 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=150840 SKV=150840 H=4 D=64 avg_ms=65.442 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=150840 SKV=150840 H=4 D=64 avg_ms=59.975 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
If I run it with 8 heads, and I reduce the seq-lens a bit, I still get results, but they're slower:
python bench_indexer_tilelang.py --heads 8 --seq-lens '4096,16384,150000'
2025-11-14 00:44:36 [TileLang:tilelang.env:WARNING]: Loading tilelang libs from dev root: /root/TileLang/build
CUDA device: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
[TILELANG_INDEXER] WRAPPER S=4096 SKV=4096 H=8 D=64 avg_ms=0.112 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=4096 SKV=4096 H=8 D=64 avg_ms=0.100 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=16384 SKV=16384 H=8 D=64 avg_ms=1.091 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=16384 SKV=16384 H=8 D=64 avg_ms=1.032 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] WRAPPER S=150000 SKV=150000 H=8 D=64 avg_ms=91.376 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
[TILELANG_INDEXER] IMPL S=150000 SKV=150000 H=8 D=64 avg_ms=88.629 over 50
[KERNEL_SRC] No TVM runtime module found on kernel mqa_attn_return_logits_kernel
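To compare the two runs on equal footing, the reported times can be converted into effective throughput. The sketch below is only an illustration: it assumes the dominant cost per call is the Q·K^T product at 2·S·SKV·H·D floating-point operations, which is an assumption about the kernel and not something the benchmark measures. The `tflops` helper is hypothetical; the timings are the IMPL rows copied from the output above.

```python
# (S, SKV, H, D, avg_ms) taken from the IMPL lines of the 4-head and 8-head runs.
runs = [
    (4096, 4096, 4, 64, 0.050),
    (16384, 16384, 4, 64, 0.701),
    (150840, 150840, 4, 64, 59.975),
    (4096, 4096, 8, 64, 0.100),
    (16384, 16384, 8, 64, 1.032),
    (150000, 150000, 8, 64, 88.629),
]


def tflops(s, skv, h, d, ms):
    """Effective TFLOP/s, ASSUMING 2*S*SKV*H*D FLOPs per call (QK^T only)."""
    flops = 2.0 * s * skv * h * d
    seconds = ms * 1e-3
    return flops / seconds / 1e12


for s, skv, h, d, ms in runs:
    rate = tflops(s, skv, h, d, ms)
    print(f"S={s:>6} SKV={skv:>6} H={h} avg_ms={ms:>7.3f} -> {rate:6.1f} TFLOP/s")
```

Under that assumption an 8-head call does twice the work of a 4-head call at the same sequence length, so wall-clock time alone is not a like-for-like comparison across head counts.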
If I increase heads to 16, it never finishes.
Isn't DeepSeek V3.2-Exp supposed to have 64 heads? Why does this kernel only work with 8?
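To pin down at which head count the benchmark stops finishing, each configuration can be run in its own process with a time limit. This is a minimal sketch, not part of the original script: `run_with_timeout` and `sweep_heads` are hypothetical helpers, the 600-second limit is an arbitrary assumption, and the command line reuses only the `--heads`, `--seq-lens` and `--mode` flags defined in bench_indexer_tilelang.py above.

```python
import subprocess
import sys


def run_with_timeout(cmd, timeout_s):
    """Run cmd and return (status, stdout); status is 'ok', 'failed' or 'timeout'."""
    try:
        proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child process before raising TimeoutExpired.
        return "timeout", ""
    status = "ok" if proc.returncode == 0 else "failed"
    return status, proc.stdout


def sweep_heads(head_counts, timeout_s=600, script="bench_indexer_tilelang.py"):
    """Benchmark one small sequence length per head count, each in a fresh process."""
    results = {}
    for heads in head_counts:
        cmd = [sys.executable, script,
               "--heads", str(heads), "--seq-lens", "4096", "--mode", "impl"]
        status, out = run_with_timeout(cmd, timeout_s)
        results[heads] = status
        print(f"heads={heads}: {status}")
        # Show only the timing lines emitted by the benchmark script.
        for line in out.splitlines():
            if line.startswith("[TILELANG_INDEXER]"):
                print("  " + line)
    return results
```

Calling `sweep_heads((4, 8, 16, 32, 64))` from the repo root would report `ok`, `failed` or `timeout` per head count, so a hang at 16 heads shows up as `timeout` without blocking the larger configurations from being tried.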
Hi @createthis, on Blackwell tilelang currently needs data to be copied to tmem explicitly in order to use utcmma. Otherwise it falls back to mma, which can be much slower. On H100 it would run faster.
Automating this process is already on our roadmap.