Usage of persistent QK (rms) norm and performance vs. normal rmsnorm
Question
For persistent QK RMS norm, what is a real example use case where it beats the regular path? https://github.com/flashinfer-ai/flashinfer/pull/1843 is the PR that introduced this feature (dispatch to either implementation).
Context
@happierpig I have a small question about the implementation.
I micro-benchmarked the real 2D-vs-3D choice in rmsnorm, and the 3D path does not seem to do better than the 2D path. For head_dim <= 256 (Qwen3 dense = 128, Qwen3-Next = 256), 3D is slower.
At head_dim = 512, 3D is faster once jobs >= 32k, and at head_dim = 1024 it is ~20% faster (not a real use case).
Bench results
I think the persistent-warp path only wins when head_dim is very large (so launch overhead is amortized) and there are enough jobs (where jobs = batch × heads). Take a typical Qwen3 decode config like the one below (a minimal sketch of the two call shapes follows the table):
head_dim = 128
# sourced from https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json#L21
num_heads = 64
# sourced from https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json#L10
batch = 8 # jobs = 512
| batch | num_heads | jobs | head_dim | 2D mean (ms) | 3D mean (ms) | Δ% (3D−2D)/2D |
|---|---|---|---|---|---|---|
| 1 | 64 | 64 | 128 | 0.017 | 0.018 | +6% slower |
| 8 | 64 | 512 | 128 | 0.017 | 0.018 | +6% slower |
| 64 | 64 | 4096 | 128 | 0.017 | 0.018 | +6% slower |
| 256 | 64 | 16384 | 128 | 0.017 | 0.018 | +6% slower |
| 1024 | 64 | 65536 | 128 | 0.022 | 0.032 | +45% slower |
| 4096 | 64 | 262144 | 128 | 0.028 | 0.045 | +61% slower |
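For reference, here is a minimal sketch of the two dispatch paths being compared, assuming flashinfer-python is installed and the wheel includes the 3D dispatch from the PR above. (The benchmark script below uses a hidden-size-flattened 2D baseline; this sketch shows the simpler per-head reshape a real QK-norm caller would use.)

```python
# Minimal sketch of the two dispatch paths (assumption: the installed
# flashinfer-python wheel includes the 3D rmsnorm dispatch from PR #1843).
# Shapes mirror the Qwen3-style decode config above (jobs = batch * num_heads = 512).
import torch
from flashinfer import norm as fin_norm

batch, num_heads, head_dim = 8, 64, 128
x = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16)
w = torch.randn(head_dim, device="cuda", dtype=torch.float16)

# 2D path: reshape (batch, num_heads, head_dim) -> (batch * num_heads, head_dim);
# free for a contiguous input, then the existing row-major kernel runs.
y_2d = fin_norm.rmsnorm(x.view(-1, head_dim), w, eps=1e-6)

# 3D path: pass the rank-3 tensor as-is; this selects the persistent kernel.
y_3d = fin_norm.rmsnorm(x, w, eps=1e-6)

# Both paths compute the same per-head RMSNorm (up to kernel-level rounding).
print((y_2d.view_as(y_3d) - y_3d).abs().max())
```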
Note
I have not tested this in the actual SGL modeling path because the FlashInfer version pinned in SGL is a bit out of date: there are compile errors with the FlashInfer 0.4.1 SHA, and I can't fix the numerous compile problems with the latest SHA either (possibly some breaking changes, which might also be why the flashinfer-python and compiled kernel versions are out of sync). I don't want to compare against a Python-only dispatch (i.e., through flashinfer_python); it's not a good experimental setup. https://github.com/sgl-project/sglang/blob/3f4cc0aff0166252bf319c94faa68326ad14dffb/sgl-kernel/CMakeLists.txt#L86
Repro
B200x8, TP 0
Here is how to run my script:
head_dim = 128, num_heads = 64, typical decode batch ≈ 8 → jobs = 512.
python3 /sgl-workspace/sglang/bench_rmsnorm_proper --device cuda:0 --outfile ./dispatch_results.csv --iters 1000 --warmup 500 --no-contiguous
The contiguity flag creates an (artificial) non-contiguous tensor view and measures the cost of the .contiguous() call so it can be adjusted for (since in practice we wouldn't bother calling the 3D warp rmsnorm otherwise).
Perf script:
#!/usr/bin/env python3
"""
FlashInfer RMSNorm dispatch+inspect benchmark.
- Adds precise diagnostics: shapes, strides, contiguity, vector alignment check (vec_size = gcd(8, head_dim) for fp16),
explicit contiguous-copy timing, and kernel-only timing via torch.cuda.Event.
- Compares:
* 3D input (possibly non-contiguous)
* 3D input but forced contiguous (explicit .contiguous()) — measures copy cost
* 2D flattened input (contiguous)
- Disables GC during timed loops.
- Saves CSV with extra diagnostic columns.
Usage:
pip install flashinfer-python flashinfer-cubin
python3 flashinfer_dispatch_inspect_benchmark.py --device cuda:0 --quick --outfile ./inspect.csv --iters 200 --warmup 50
"""
import argparse, csv, time, statistics, gc, os, sys
from math import gcd
try:
import torch
except Exception as e:
print("ERROR: torch import failed:", e, file=sys.stderr); raise SystemExit(1)
try:
import flashinfer
from flashinfer import norm as fin_norm
except Exception as e:
print("ERROR: flashinfer import failed:", e, file=sys.stderr); raise SystemExit(1)
def make_inputs(batch, num_heads, head_dim, ndim, dtype=torch.float16, device="cuda:0", contiguous=True):
if ndim == 3:
x = torch.randn(batch, num_heads, head_dim, device=device, dtype=dtype)
w = torch.randn(head_dim, device=device, dtype=dtype)
else:
hidden = num_heads * head_dim
x = torch.randn(batch, hidden, device=device, dtype=dtype)
w = torch.randn(hidden, device=device, dtype=dtype)
if not contiguous:
# create a likely non-contiguous view; for 3D, try transpose-trick; for 2D, transpose
if ndim == 3:
x = x.transpose(1,2) # shape (batch, head_dim, num_heads)
x = x[:, :head_dim, :].transpose(1,2) # back to (batch, num_heads, head_dim) but not guaranteed contiguous
else:
x = x.t() # non-contiguous view
return x, w
def measure_contiguous_copy_time(x):
# return time in ms to make x.contiguous()
if not torch.cuda.is_available():
        # no CUDA: skip event timing but still return a contiguous copy so callers can unpack
        return None, x.contiguous()
torch.cuda.synchronize()
s = torch.cuda.Event(enable_timing=True)
e = torch.cuda.Event(enable_timing=True)
s.record()
y = x.contiguous()
e.record()
torch.cuda.synchronize()
return s.elapsed_time(e), y
def kernel_time_rmsnorm(inp, weight, out=None, enable_pdl=False, iters=50, warmup=10):
# Measure kernel-only time using CUDA events; returns mean ms, std ms
if not torch.cuda.is_available():
# fallback to CPU timer for completeness (less accurate)
def call():
return fin_norm.rmsnorm(inp, weight, eps=1e-6, out=out, enable_pdl=enable_pdl)
# warmup
for _ in range(warmup):
call()
times = []
for _ in range(iters):
t0 = time.perf_counter()
call()
t1 = time.perf_counter()
times.append((t1-t0)*1000.0)
return statistics.mean(times), statistics.stdev(times)
# GPU path: use events and explicit sync
torch.cuda.synchronize()
ev_start = torch.cuda.Event(enable_timing=True)
ev_end = torch.cuda.Event(enable_timing=True)
# warmup
for _ in range(warmup):
fin_norm.rmsnorm(inp, weight, eps=1e-6, out=out, enable_pdl=enable_pdl)
times = []
for _ in range(iters):
ev_start.record()
fin_norm.rmsnorm(inp, weight, eps=1e-6, out=out, enable_pdl=enable_pdl)
ev_end.record()
torch.cuda.synchronize()
times.append(ev_start.elapsed_time(ev_end))
return statistics.mean(times), statistics.stdev(times)
def inspect_input(x, name):
s = {
"name": name,
"shape": tuple(x.shape),
"stride": tuple(x.stride()),
"is_contiguous": bool(x.is_contiguous()),
"device": str(x.device),
"dtype": str(x.dtype),
}
# compute fp16 vector alignment suggestion (vec_size = gcd(8, head_dim) for fp16)
if x.ndim >= 1:
d = x.shape[-1]
s["last_dim"] = d
s["vec_align"] = int(gcd(8, d))
return s
def run_case(batch, num_heads, head_dim, ndim, device, iters, warmup, enable_pdl, use_out, contiguous):
dtype = torch.float16
# prepare inputs (possibly non-contiguous)
x3, w3 = make_inputs(batch, num_heads, head_dim, ndim, dtype=dtype, device=device, contiguous=contiguous)
# inspect
info = inspect_input(x3, "original")
# measure contiguous copy cost if it's not contiguous
copy_ms = None
y_contig = None
if not x3.is_contiguous():
copy_ms, y_contig = measure_contiguous_copy_time(x3)
# 3D run as-is (if ndim==2, this is actually 2D)
out_tensor = None
if use_out:
out_tensor = torch.empty_like(x3)
    mean3_ms, std3_ms = kernel_time_rmsnorm(x3, w3, out=out_tensor, enable_pdl=enable_pdl, iters=iters, warmup=warmup)
# if we measured contiguous copy, also measure running on contig version
if y_contig is not None:
# measure kernel on contiguous copy
out_tensor_contig = torch.empty_like(y_contig) if use_out else None
        mean3_contig_ms, std3_contig_ms = kernel_time_rmsnorm(y_contig, w3, out=out_tensor_contig, enable_pdl=enable_pdl, iters=iters, warmup=warmup)
else:
mean3_contig_ms, std3_contig_ms = mean3_ms, std3_ms
# Flatten to 2D version (explicit contiguous flatten)
if ndim == 3:
hidden = num_heads * head_dim
x2 = x3.view(-1, hidden).contiguous()
        # The flattened 2D RMSNorm expects a weight of length hidden (= num_heads * head_dim),
        # which differs from the per-head 3D weight of length head_dim. To compare kernel
        # behavior only, build a fresh 2D weight of length hidden.
        w2 = torch.randn(hidden, device=device, dtype=dtype)
else:
x2 = x3.contiguous()
w2 = w3
out2 = torch.empty_like(x2) if use_out else None
mean2_ms, std2_ms = kernel_time_rmsnorm(x2, w2, out=out2, enable_pdl=enable_pdl, iters=iters, warmup=warmup)
# explicit contiguous cost if we forced contig above
contig_cost_ms = None
if y_contig is not None:
contig_cost_ms = copy_ms
# jobs (what the warp kernel sees): total_tokens * num_heads if 3D; if 2D flattened we used view(-1, hidden)
jobs = (x3.shape[0] * x3.shape[1]) if (ndim == 3 and x3.ndim == 3) else x2.shape[0]
return {
"batch": batch,
"num_heads": num_heads,
"head_dim": head_dim,
"input_ndim": ndim,
"is_contiguous_original": info["is_contiguous"],
"orig_shape": info["shape"],
"orig_stride": info["stride"],
"vec_align": info.get("vec_align"),
"copy_ms": contig_cost_ms,
"mean3_ms": mean3_ms,
"std3_ms": std3_ms,
"mean3_contig_ms": mean3_contig_ms,
"std3_contig_ms": std3_contig_ms,
"mean2_ms": mean2_ms,
"std2_ms": std2_ms,
"jobs": jobs,
}
def recommend_thresholds(results):
    # For each head_dim, find the smallest jobs count where the 3D kernel beats the
    # flattened 2D baseline by >= 5% (mean3 < mean2 * 0.95). Both timings come from
    # the same 3D run, so no cross-config matching is needed.
    per_head = {}
    for r in results:
        if r["input_ndim"] != 3:
            continue
        per_head.setdefault(r["head_dim"], []).append((r["jobs"], r["mean2_ms"], r["mean3_ms"]))
    recommendations = []
    for hd, rows in per_head.items():
        rows.sort(key=lambda x: x[0])
        chosen = None
        for jobs, mean2, mean3 in rows:
            if mean3 < mean2 * 0.95:
                chosen = jobs
                break
        recommendations.append({"head_dim": hd, "jobs_min_recommend": chosen if chosen is not None else float("inf")})
    return recommendations
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--device", default="cuda:0")
parser.add_argument("--outfile", default="./inspect_dispatch_results.csv")
parser.add_argument("--iters", type=int, default=50)
parser.add_argument("--warmup", type=int, default=10)
parser.add_argument("--quick", action="store_true")
parser.add_argument("--enable_pdl", action="store_true")
parser.add_argument("--use_out", action="store_true")
parser.add_argument("--contiguous", dest="contiguous", action="store_true")
parser.add_argument("--no-contiguous", dest="contiguous", action="store_false")
parser.set_defaults(contiguous=True)
args = parser.parse_args()
device = args.device
if "cuda" in device and not torch.cuda.is_available():
raise SystemExit("CUDA requested but not available on this machine")
if args.quick:
grid = [
{"batch":1, "num_heads":32, "head_dim":128, "ndim":3},
{"batch":64, "num_heads":32, "head_dim":128, "ndim":3},
{"batch":1024, "num_heads":32, "head_dim":128, "ndim":3},
{"batch":1, "num_heads":1, "head_dim":4096, "ndim":2},
{"batch":1024, "num_heads":1, "head_dim":4096, "ndim":2},
# add head_dim 256 test
{"batch":64, "num_heads":16, "head_dim":256, "ndim":3},
{"batch":1024, "num_heads":16, "head_dim":256, "ndim":3},
]
else:
batches = [1,8,64,256,1024,4096]
num_heads_list = [4,8,16,32]
head_dims = [16,32,64,128,256,512,1024]
grid = []
for b in batches:
for nh in num_heads_list:
for hd in head_dims:
grid.append({"batch": b, "num_heads": nh, "head_dim": hd, "ndim":3})
for b in batches:
for hidden in [256,1024,4096,8192]:
grid.append({"batch": b, "num_heads": 1, "head_dim": hidden, "ndim":2})
results = []
total = len(grid); idx = 0
for cfg in grid:
idx += 1
batch = cfg["batch"]; num_heads = cfg["num_heads"]; head_dim = cfg["head_dim"]; ndim = cfg["ndim"]
print(f"[{idx}/{total}] Running {batch=} {num_heads=} {head_dim=} {ndim=} contiguous={args.contiguous}", flush=True)
try:
            # Prevent GC during each timed run so collection pauses don't pollute the measurements
was_gc = gc.isenabled()
if was_gc:
gc.disable()
res = run_case(batch, num_heads, head_dim, ndim, args.device, args.iters, args.warmup, args.enable_pdl, args.use_out, args.contiguous)
if was_gc:
gc.enable()
print(" -> jobs=%d 2D mean=%.3fms 3D mean=%.3fms 3D(contig)=%.3fms copy_ms=%s vec_align=%s strides=%s" %
(res["jobs"], res["mean2_ms"], res["mean3_ms"], res["mean3_contig_ms"], str(res["copy_ms"]), str(res["vec_align"]), str(res["orig_stride"])),
flush=True)
results.append(res)
except Exception as e:
print(" ERROR for config", cfg, e, flush=True)
# save CSV
keys = ["batch","num_heads","head_dim","input_ndim","is_contiguous_original","orig_shape","orig_stride","vec_align","copy_ms","mean3_ms","std3_ms","mean3_contig_ms","std3_contig_ms","mean2_ms","std2_ms","jobs"]
with open(args.outfile, "w", newline="") as f:
writer = csv.DictWriter(f, fieldnames=keys)
writer.writeheader()
for r in results:
# flatten tuple fields to strings
r_out = {k: r.get(k, "") for k in keys}
if isinstance(r_out.get("orig_shape"), tuple): r_out["orig_shape"] = str(r_out["orig_shape"])
if isinstance(r_out.get("orig_stride"), tuple): r_out["orig_stride"] = str(r_out["orig_stride"])
writer.writerow(r_out)
recs = recommend_thresholds(results)
print("Recommendations (head_dim -> recommended minimum jobs for 3D warp kernel to win by >=5%):")
for r in recs:
print(r)
print("Saved CSV to", args.outfile)
if __name__ == '__main__':
main()
The results (the errors are just from trying non-contiguous inputs on the 2D rmsnorm, which is expected): https://gist.github.com/vincentzed/f4b00d40897ad38fb4136e546d901bfc
And the CSV: https://gist.github.com/vincentzed/d43074a3ed390a3bec679a15cfa54fe0
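If it helps when reading the CSVs, here is a quick summary sketch (assuming pandas is installed; the column names come from the csv.DictWriter fieldnames in the script above):

```python
# Summarize the 3D-vs-2D delta per (head_dim, jobs) from the CSV emitted above.
# Assumption: pandas is available; "dispatch_results.csv" is the --outfile path
# used in the repro command.
import pandas as pd

df = pd.read_csv("dispatch_results.csv")
df3 = df[df["input_ndim"] == 3].copy()            # only 3D configs carry both timings
df3["delta_pct"] = (df3["mean3_ms"] - df3["mean2_ms"]) / df3["mean2_ms"] * 100.0
print(df3.groupby(["head_dim", "jobs"])["delta_pct"].mean().round(1).to_string())
```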
> For persistent QK RMS norm, what is a real example use case where it beats the regular path?

@happierpig can you answer this question? I think it was used in your vLLM integration.
> I have not tested this in the actual SGL modeling path because the FlashInfer version pinned in SGL is a bit out of date: there are compile errors with the FlashInfer 0.4.1 SHA, and I can't fix the numerous compile problems with the latest SHA either (possibly some breaking changes, which might also be why the flashinfer-python and compiled kernel versions are out of sync).

Early this year SGL ported some FlashInfer kernels (including norm) into sgl-kernel with standalone packaging because FlashInfer had moved to JIT at the time. I don't think that's still necessary at this point, and I would prefer to use the FlashInfer Python APIs directly instead of a source-level dependency. As you mentioned, we have had a lot of refactoring since then.
> I don't want to compare against a Python-only dispatch (i.e., through flashinfer_python); it's not a good experimental setup. https://github.com/sgl-project/sglang/blob/3f4cc0aff0166252bf319c94faa68326ad14dffb/sgl-kernel/CMakeLists.txt#L86

I don't understand "it's not a good experimental setup", can you explain more?
> I don't understand "it's not a good experimental setup", can you explain more?

Based on my understanding of sgl-kernel, in order to use it at runtime I have two options:
- Dispatch to flashinfer_python in the layernorm path in SGL
- Recompile sgl-kernel with an updated FlashInfer SHA (the one in the CMakeLists)
Since the end-level integration does not involve rewriting wrappers in SGL (all it requires is recompiling against the latest FlashInfer SHA), I want to test it by compiling the wheel and running through sgl-kernel instead; a hypothetical sketch of the first option is included below for reference. Upgrading the FlashInfer wheel SHA hits some problems that I haven't had time to look into yet (otherwise we could test this empirically more easily).
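For reference, a hypothetical sketch of what option 1 could look like (the wrapper name is made up and this is not actual sglang code; it just shows the call that would replace the sgl-kernel op):

```python
# Hypothetical option-1 sketch: route QK norm through flashinfer-python instead of
# the sgl-kernel op. The wrapper name and call site are illustrative, not sglang code.
import torch
from flashinfer import norm as fin_norm

def qk_rmsnorm_via_flashinfer(q: torch.Tensor, weight: torch.Tensor,
                              eps: float = 1e-6) -> torch.Tensor:
    # q: (num_tokens, num_heads, head_dim). Passing the rank-3 (possibly
    # non-contiguous) tensor as-is lets flashinfer pick the 3D persistent kernel;
    # q.reshape(-1, head_dim) would force the existing 2D kernel instead.
    return fin_norm.rmsnorm(q, weight, eps=eps)
```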
> Early this year SGL ported some FlashInfer kernels (including norm) into sgl-kernel with standalone packaging because FlashInfer had moved to JIT at the time. I don't think that's still necessary at this point, and I would prefer to use the FlashInfer Python APIs directly instead of a source-level dependency. As you mentioned, we have had a lot of refactoring since then.

SGL has only a little bit of usage of the FlashInfer Python wrapper, but some things (like fused RMS norm / RMS norm) are evidently still in the sgl-kernel wheel.
> Upgrading the FlashInfer wheel SHA hits some problems

Yes, because we changed the API, and it would take a lot of effort to update that on the sglang side (we don't guarantee the stability of the C++ APIs).
> but some things (like fused RMS norm / RMS norm) are evidently still in the sgl-kernel wheel

sglang used to rely on FlashInfer's Python APIs for these functions (rmsnorm, etc.) back when we released the AOT wheel only. I don't think that's still necessary, as we now have stable release pipelines for both AOT and JIT wheels on the FlashInfer side.
@vincentzed Sorry for the late reply. Here is what I have in mind:
- This 3D implementation is not about a performance advantage from being persistent; instead, it is there to support non-contiguous input. I haven't tested on Blackwell, but running your script on H200 gives comparable numbers:
batch,num_heads,head_dim,input_ndim,is_contiguous_original,orig_shape,orig_stride,vec_align,copy_ms,mean3_ms,std3_ms,mean3_contig_ms,std3_contig_ms,mean2_ms,std2_ms,jobs
1,1,256,2,True,"(1, 256)","(256, 1)",8,,0.005981000534163133,1.962612882453122e-05,0.005981000534163133,1.962612882453122e-05,0.0059850137545446195,1.1544483438337546e-05,1
1,1,1024,2,True,"(1, 1024)","(1024, 1)",8,,0.006034079931298585,1.8040962399209137e-05,0.006034079931298585,1.8040962399209137e-05,0.006040765351815367,1.156956099460412e-05,1
1,1,4096,2,True,"(1, 4096)","(4096, 1)",8,,0.006271360305756527,4.643516107634431e-05,0.006271360305756527,4.643516107634431e-05,0.006258204057140081,9.068072653489733e-06,1
1,1,8192,2,True,"(1, 8192)","(8192, 1)",8,,0.006633921931978511,1.2879135969290683e-05,0.006633921931978511,1.2879135969290683e-05,0.006636676068551759,1.6487512540588818e-05,1
8,1,256,2,True,"(8, 256)","(256, 1)",8,,0.005904900624492044,1.5947473415889564e-05,0.005904900624492044,1.5947473415889564e-05,0.005900192811473293,1.8185705957831912e-05,8
8,1,1024,2,True,"(8, 1024)","(1024, 1)",8,,0.005969399579423343,3.048738023746512e-05,0.005969399579423343,3.048738023746512e-05,0.005935703212387159,1.3387354406690704e-05,8
8,1,4096,2,True,"(8, 4096)","(4096, 1)",8,,0.006274061337838674,1.227426043422927e-05,0.006274061337838674,1.227426043422927e-05,0.006288149234939134,1.9953791749279674e-05,8
8,1,8192,2,True,"(8, 8192)","(8192, 1)",8,,0.006693300180084355,1.9129815377125664e-05,0.006693300180084355,1.9129815377125664e-05,0.00669569863868943,8.8176767152521e-06,8
64,1,256,2,True,"(64, 256)","(256, 1)",8,,0.005978307708902481,4.624543841693019e-05,0.005978307708902481,4.624543841693019e-05,0.0060078992061522105,9.440892701249972e-06,64
- Non-contiguous inputs are common when fusing the QKV projection to squeeze out GEMM performance. In that case, an additional .contiguous() is called to make the existing 2D kernel work, which leads to 2x more latency (a hedged illustration follows after this list). Here is a simple example: https://gist.github.com/happierpig/8a70d7d49243a25b0ee11049a4bd638e
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=1, default_ms=0.011ms, customized_ms=0.009ms
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=16, default_ms=0.022ms, customized_ms=0.009ms
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=64, default_ms=0.022ms, customized_ms=0.010ms
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=128, default_ms=0.028ms, customized_ms=0.010ms
- Anyhow, this modification should be transparent to the Python-side API.
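To make the non-contiguity concrete, here is a hedged sketch (not the linked gist; the hidden size and projection weights are made up for illustration) of how slicing Q out of a fused QKV GEMM output yields a strided 3D view, which the 2D kernel can only consume after an explicit copy while the 3D kernel takes it directly:

```python
# Hedged illustration of the fused-QKV case; hidden size and weights are made up.
import torch
from flashinfer import norm as fin_norm

batch, num_qo_heads, num_kv_heads, head_dim = 16, 64, 8, 128
hidden = 4096                                         # hypothetical model width
qkv_dim = (num_qo_heads + 2 * num_kv_heads) * head_dim

x = torch.randn(batch, hidden, device="cuda", dtype=torch.float16)
w_qkv = torch.randn(hidden, qkv_dim, device="cuda", dtype=torch.float16)
w_q = torch.randn(head_dim, device="cuda", dtype=torch.float16)

qkv = x @ w_qkv                                       # fused QKV projection output
q = qkv[:, : num_qo_heads * head_dim].view(batch, num_qo_heads, head_dim)
assert not q.is_contiguous()                          # strided view into the fused buffer

# 2D kernel: needs an explicit copy first (the extra latency measured above).
out_2d = fin_norm.rmsnorm(q.contiguous().view(-1, head_dim), w_q, eps=1e-6)

# 3D kernel: consumes the strided view directly, no copy.
out_3d = fin_norm.rmsnorm(q, w_q, eps=1e-6)
```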