
Usage of persistent QK (rms) norm and performance vs. normal rmsnorm

Open vincentzed opened this issue 2 months ago • 5 comments

Question

For persistent QK RMS norm, what is a real example use case where it beats the regular path? The PR that introduced this feature (dispatching to either implementation): https://github.com/flashinfer-ai/flashinfer/pull/1843

Context

@happierpig I have a small question about the implementation.

I micro-benchmarked the real case of choosing between the 2D and 3D rmsnorm paths, and 3D does not seem to do better than 2D. For head_dim <= 256 (Qwen3 dense = 128, Qwen3-Next = 256), 3D is slower.

When head_dim = 512, 3D is faster only once jobs >= 32k. At head_dim = 1024 it is ~20% faster, but that is not a real use case.
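
For context, a minimal sketch (my own, not from the PR) of the two call shapes being compared: the same (batch, num_heads, head_dim) tensor passed to flashinfer.norm.rmsnorm as 3D input versus flattened to per-head 2D rows. This assumes a flashinfer build that includes the 3D support from PR #1843; the benchmark below compares against a full-hidden-width 2D weight instead, but the idea is the same.

import torch
from flashinfer import norm as fin_norm

batch, num_heads, head_dim = 8, 64, 128  # typical Qwen3 decode shape
x = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=torch.float16)
w = torch.randn(head_dim, device="cuda", dtype=torch.float16)

y3 = fin_norm.rmsnorm(x, w, eps=1e-6)                     # 3D input path
y2 = fin_norm.rmsnorm(x.view(-1, head_dim), w, eps=1e-6)  # 2D path on a flattened view
# Both normalize each head_dim-length row, so the outputs should agree (up to fp16 tolerance).
torch.testing.assert_close(y3.view(-1, head_dim), y2)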

Bench results

I think the persistent warp path only wins when head_dim is large (so launch overhead is amortized) and there are enough jobs (here jobs = batch × heads). In a config like the one below (typical Qwen3 decode):

head_dim = 128
# sourced from https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json#L21
num_heads = 64
# sourced from https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json#L10
batch = 8  #  jobs = 512
| batch | num_heads | jobs | head_dim | 2D mean (ms) | 3D mean (ms) | Δ% ((3D − 2D)/2D) |
|------:|----------:|-----:|---------:|-------------:|-------------:|-------------------|
| 1 | 64 | 64 | 128 | 0.017 | 0.018 | +6% (3D slower) |
| 8 | 64 | 512 | 128 | 0.017 | 0.018 | +6% (3D slower) |
| 64 | 64 | 4096 | 128 | 0.017 | 0.018 | +6% (3D slower) |
| 256 | 64 | 16384 | 128 | 0.017 | 0.018 | +6% (3D slower) |
| 1024 | 64 | 65536 | 128 | 0.022 | 0.032 | +45% (3D slower) |
| 4096 | 64 | 262144 | 128 | 0.028 | 0.045 | +61% (3D slower) |

Note

I have not tested this in the actual SGLang modeling path. In SGLang, the flashinfer version is a bit out of date; there seem to be compile errors with the flashinfer 0.4.1 SHA, and I can't fix the numerous compile problems with the latest SHA (maybe some breaking changes, which might be why the flashinfer Python and compiled kernel versions are out of sync). I also don't want to compare by dispatching through Python only (i.e. through flashinfer_python), since that is not a good experimental setup. https://github.com/sgl-project/sglang/blob/3f4cc0aff0166252bf319c94faa68326ad14dffb/sgl-kernel/CMakeLists.txt#L86

vincentzed · Oct 27 '25 00:10

Repro

Hardware: B200x8, TP 0. Here is my script and how I use it: head_dim = 128, num_heads = 64, typical decode batch ≈ 8 → jobs = 512.

python3 /sgl-workspace/sglang/bench_rmsnorm_proper --device cuda:0 --outfile ./dispatch_results.csv --iters 1000  --warmup 500  --no-contiguous

The --no-contiguous flag creates an artificial non-contiguous tensor view and measures the cost of the .contiguous() call so it can be adjusted for (since in practice we wouldn't bother calling the 3D warp rmsnorm anyway).

Perf script:

#!/usr/bin/env python3
"""
FlashInfer RMSNorm dispatch+inspect benchmark.

- Adds precise diagnostics: shapes, strides, contiguity, vector alignment check (vec_size = gcd(8, head_dim) for fp16),
  explicit contiguous-copy timing, and kernel-only timing via torch.cuda.Event.
- Compares:
    * 3D input (possibly non-contiguous)
    * 3D input but forced contiguous (explicit .contiguous()) — measures copy cost
    * 2D flattened input (contiguous)
- Disables GC during timed loops.
- Saves CSV with extra diagnostic columns.

Usage:
    pip install flashinfer-python flashinfer-cubin
    python3 flashinfer_dispatch_inspect_benchmark.py --device cuda:0 --quick --outfile ./inspect.csv --iters 200 --warmup 50
"""
import argparse, csv, time, statistics, gc, os, sys
from math import gcd
try:
    import torch
except Exception as e:
    print("ERROR: torch import failed:", e, file=sys.stderr); raise SystemExit(1)

try:
    import flashinfer
    from flashinfer import norm as fin_norm
except Exception as e:
    print("ERROR: flashinfer import failed:", e, file=sys.stderr); raise SystemExit(1)

def make_inputs(batch, num_heads, head_dim, ndim, dtype=torch.float16, device="cuda:0", contiguous=True):
    if ndim == 3:
        x = torch.randn(batch, num_heads, head_dim, device=device, dtype=dtype)
        w = torch.randn(head_dim, device=device, dtype=dtype)
    else:
        hidden = num_heads * head_dim
        x = torch.randn(batch, hidden, device=device, dtype=dtype)
        w = torch.randn(hidden, device=device, dtype=dtype)
    if not contiguous:
        # create a likely non-contiguous view; for 3D, try a transpose trick; for 2D, transpose
        if ndim == 3:
            x = x.transpose(1, 2)  # shape (batch, head_dim, num_heads)
            # NOTE: the full slice plus transpose back restores the original strides, so this
            # typically ends up contiguous again; only the 2D branch is reliably non-contiguous
            x = x[:, :head_dim, :].transpose(1, 2)  # back to (batch, num_heads, head_dim)
        else:
            x = x.t()  # non-contiguous view
    return x, w

def measure_contiguous_copy_time(x):
    # return time in ms to make x.contiguous()
    if not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    s = torch.cuda.Event(enable_timing=True)
    e = torch.cuda.Event(enable_timing=True)
    s.record()
    y = x.contiguous()
    e.record()
    torch.cuda.synchronize()
    return s.elapsed_time(e), y

def kernel_time_rmsnorm(inp, weight, out=None, enable_pdl=False, iters=50, warmup=10):
    # Measure kernel-only time using CUDA events; returns mean ms, std ms
    if not torch.cuda.is_available():
        # fallback to CPU timer for completeness (less accurate)
        def call():
            return fin_norm.rmsnorm(inp, weight, eps=1e-6, out=out, enable_pdl=enable_pdl)
        # warmup
        for _ in range(warmup):
            call()
        times = []
        for _ in range(iters):
            t0 = time.perf_counter()
            call()
            t1 = time.perf_counter()
            times.append((t1-t0)*1000.0)
        return statistics.mean(times), statistics.stdev(times)
    # GPU path: use events and explicit sync
    torch.cuda.synchronize()
    ev_start = torch.cuda.Event(enable_timing=True)
    ev_end = torch.cuda.Event(enable_timing=True)
    # warmup
    for _ in range(warmup):
        fin_norm.rmsnorm(inp, weight, eps=1e-6, out=out, enable_pdl=enable_pdl)
    times = []
    for _ in range(iters):
        ev_start.record()
        fin_norm.rmsnorm(inp, weight, eps=1e-6, out=out, enable_pdl=enable_pdl)
        ev_end.record()
        torch.cuda.synchronize()
        times.append(ev_start.elapsed_time(ev_end))
    return statistics.mean(times), statistics.stdev(times)

def inspect_input(x, name):
    s = {
        "name": name,
        "shape": tuple(x.shape),
        "stride": tuple(x.stride()),
        "is_contiguous": bool(x.is_contiguous()),
        "device": str(x.device),
        "dtype": str(x.dtype),
    }
    # compute fp16 vector alignment suggestion (vec_size = gcd(8, head_dim) for fp16)
    if x.ndim >= 1:
        d = x.shape[-1]
        s["last_dim"] = d
        s["vec_align"] = int(gcd(8, d))
    return s

def run_case(batch, num_heads, head_dim, ndim, device, iters, warmup, enable_pdl, use_out, contiguous):
    dtype = torch.float16
    # prepare inputs (possibly non-contiguous)
    x3, w3 = make_inputs(batch, num_heads, head_dim, ndim, dtype=dtype, device=device, contiguous=contiguous)
    # inspect
    info = inspect_input(x3, "original")
    # measure contiguous copy cost if it's not contiguous
    copy_ms = None
    y_contig = None
    if not x3.is_contiguous():
        copy_ms, y_contig = measure_contiguous_copy_time(x3)
    # 3D run as-is (if ndim==2, this is actually 2D)
    out_tensor = None
    if use_out:
        out_tensor = torch.empty_like(x3)
    mean3_ms, std3_ms = kernel_time_rmsnorm(x3, w3, out=out_tensor, enable_pdl=enable_pdl, iters=iters, warmup=warmup)
    # if we measured contiguous copy, also measure running on contig version
    if y_contig is not None:
        # measure kernel on contiguous copy
        out_tensor_contig = torch.empty_like(y_contig) if use_out else None
        mean3_contig_ms, std3_contig_ms = kernel_time_rmsnorm(y_contig, w3, out=out_tensor_contig, enable_pdl=enable_pdl, iters=iters, warmup=warmup)
    else:
        mean3_contig_ms, std3_contig_ms = mean3_ms, std3_ms
    # Flatten to 2D version (explicit contiguous flatten)
    if ndim == 3:
        hidden = num_heads * head_dim
        x2 = x3.view(-1, hidden).contiguous()
        # flashinfer.rmsnorm normalizes over the last dim, so the flattened 2D comparison
        # needs a weight of length hidden (= num_heads * head_dim). This is not numerically
        # equivalent to the per-head norm; it only compares kernel behavior.
        w2 = torch.randn(hidden, device=device, dtype=dtype)
    else:
        x2 = x3.contiguous()
        w2 = w3
    out2 = torch.empty_like(x2) if use_out else None
    mean2_ms, std2_ms = kernel_time_rmsnorm(x2, w2, out=out2, enable_pdl=enable_pdl, iters=iters, warmup=warmup)
    # explicit contiguous cost if we forced contig above
    contig_cost_ms = None
    if y_contig is not None:
        contig_cost_ms = copy_ms
    # jobs (what the warp kernel sees): total_tokens * num_heads if 3D; if 2D flattened we used view(-1, hidden)
    jobs = (x3.shape[0] * x3.shape[1]) if (ndim == 3 and x3.ndim == 3) else x2.shape[0]
    return {
        "batch": batch,
        "num_heads": num_heads,
        "head_dim": head_dim,
        "input_ndim": ndim,
        "is_contiguous_original": info["is_contiguous"],
        "orig_shape": info["shape"],
        "orig_stride": info["stride"],
        "vec_align": info.get("vec_align"),
        "copy_ms": contig_cost_ms,
        "mean3_ms": mean3_ms,
        "std3_ms": std3_ms,
        "mean3_contig_ms": mean3_contig_ms,
        "std3_contig_ms": std3_contig_ms,
        "mean2_ms": mean2_ms,
        "std2_ms": std2_ms,
        "jobs": jobs,
    }

def recommend_thresholds(results):
    # same as before: find where mean3 < mean2*0.95
    grouped = {}
    for r in results:
        key = (r["head_dim"], r["jobs"])
        grouped.setdefault(key, {})[r["input_ndim"]] = r
    per_head = {}
    for (hd, jobs), val in grouped.items():
        if 2 in val and 3 in val:
            mean2 = val[2]["mean2_ms"]; mean3 = val[3]["mean3_ms"]
            per_head.setdefault(hd, []).append((jobs, mean2, mean3))
    recommendations = []
    for hd, rows in per_head.items():
        rows.sort(key=lambda x: x[0])
        chosen = None
        for jobs, mean2, mean3 in rows:
            if mean3 < mean2 * 0.95:
                chosen = jobs
                break
        recommendations.append({"head_dim": hd, "jobs_min_recommend": chosen if chosen is not None else float("inf")})
    return recommendations

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--outfile", default="./inspect_dispatch_results.csv")
    parser.add_argument("--iters", type=int, default=50)
    parser.add_argument("--warmup", type=int, default=10)
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--enable_pdl", action="store_true")
    parser.add_argument("--use_out", action="store_true")
    parser.add_argument("--contiguous", dest="contiguous", action="store_true")
    parser.add_argument("--no-contiguous", dest="contiguous", action="store_false")
    parser.set_defaults(contiguous=True)
    args = parser.parse_args()

    device = args.device
    if "cuda" in device and not torch.cuda.is_available():
        raise SystemExit("CUDA requested but not available on this machine")

    if args.quick:
        grid = [
            {"batch":1, "num_heads":32, "head_dim":128, "ndim":3},
            {"batch":64, "num_heads":32, "head_dim":128, "ndim":3},
            {"batch":1024, "num_heads":32, "head_dim":128, "ndim":3},
            {"batch":1, "num_heads":1, "head_dim":4096, "ndim":2},
            {"batch":1024, "num_heads":1, "head_dim":4096, "ndim":2},
            # add head_dim 256 test
            {"batch":64, "num_heads":16, "head_dim":256, "ndim":3},
            {"batch":1024, "num_heads":16, "head_dim":256, "ndim":3},
        ]
    else:
        batches = [1,8,64,256,1024,4096]
        num_heads_list = [4,8,16,32]
        head_dims = [16,32,64,128,256,512,1024]
        grid = []
        for b in batches:
            for nh in num_heads_list:
                for hd in head_dims:
                    grid.append({"batch": b, "num_heads": nh, "head_dim": hd, "ndim":3})
        for b in batches:
            for hidden in [256,1024,4096,8192]:
                grid.append({"batch": b, "num_heads": 1, "head_dim": hidden, "ndim":2})

    results = []
    total = len(grid); idx = 0
    for cfg in grid:
        idx += 1
        batch = cfg["batch"]; num_heads = cfg["num_heads"]; head_dim = cfg["head_dim"]; ndim = cfg["ndim"]
        print(f"[{idx}/{total}] Running {batch=} {num_heads=} {head_dim=} {ndim=} contiguous={args.contiguous}", flush=True)
        try:
            # Prevent GC during each run so collection pauses don't pollute the timings
            was_gc = gc.isenabled()
            if was_gc:
                gc.disable()
            res = run_case(batch, num_heads, head_dim, ndim, args.device, args.iters, args.warmup, args.enable_pdl, args.use_out, args.contiguous)
            if was_gc:
                gc.enable()
            print("  -> jobs=%d 2D mean=%.3fms 3D mean=%.3fms 3D(contig)=%.3fms copy_ms=%s vec_align=%s strides=%s" %
                  (res["jobs"], res["mean2_ms"], res["mean3_ms"], res["mean3_contig_ms"], str(res["copy_ms"]), str(res["vec_align"]), str(res["orig_stride"])),
                  flush=True)
            results.append(res)
        except Exception as e:
            print("  ERROR for config", cfg, e, flush=True)

    # save CSV
    keys = ["batch","num_heads","head_dim","input_ndim","is_contiguous_original","orig_shape","orig_stride","vec_align","copy_ms","mean3_ms","std3_ms","mean3_contig_ms","std3_contig_ms","mean2_ms","std2_ms","jobs"]
    with open(args.outfile, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=keys)
        writer.writeheader()
        for r in results:
            # flatten tuple fields to strings
            r_out = {k: r.get(k, "") for k in keys}
            if isinstance(r_out.get("orig_shape"), tuple): r_out["orig_shape"] = str(r_out["orig_shape"])
            if isinstance(r_out.get("orig_stride"), tuple): r_out["orig_stride"] = str(r_out["orig_stride"])
            writer.writerow(r_out)

    recs = recommend_thresholds(results)
    print("Recommendations (head_dim -> recommended minimum jobs for 3D warp kernel to win by >=5%):")
    for r in recs:
        print(r)
    print("Saved CSV to", args.outfile)

if __name__ == '__main__':
    main()

The results (the errors are just from trying non-contiguous input on the 2D rmsnorm, which is expected): https://gist.github.com/vincentzed/f4b00d40897ad38fb4136e546d901bfc

And the CSV: https://gist.github.com/vincentzed/d43074a3ed390a3bec679a15cfa54fe0

vincentzed · Oct 27 '25 00:10

For persistent QK RMS norm, what is a real example use case where it beats the regular path?

@happierpig can you answer this question? I think it was used in your vLLM integration.

In SGLang, the flashinfer version is a bit out of date; there seem to be compile errors with the flashinfer 0.4.1 SHA, and I can't fix the numerous compile problems with the latest SHA (maybe some breaking changes, which might be why the flashinfer Python and compiled kernel versions are out of sync).

Early this year sgl ported some flashinfer kernels (including norm) into sgl-kernel with standalone packaging because flashinfer had moved to JIT at the time. I don't think that's still necessary at this point, and I'd prefer to use the flashinfer Python APIs directly instead of a source-level dependency. As you mentioned, we have had lots of refactors since then.

I also don't want to compare by dispatching through Python only (i.e. through flashinfer_python), since that is not a good experimental setup. https://github.com/sgl-project/sglang/blob/3f4cc0aff0166252bf319c94faa68326ad14dffb/sgl-kernel/CMakeLists.txt#L86

I don't understand "it's not a good experimental setup"; can you explain more?

yzh119 · Oct 27 '25 05:10

I don't understand "it's not a good experimental setup"; can you explain more?

Based on my understanding of sgl-kernel, in order to use it at runtime I have two options:

  1. Dispatch to flashinfer_python in the layernorm path in sgl (sketched below)
  2. Recompile sgl-kernel with an updated flashinfer SHA (the one in the CMakeLists)

Since the end-level integration does not involve rewriting wrappers in sgl (all it requires is recompiling against the latest flashinfer SHA), I want to test it by compiling the wheel and running sgl-kernel instead. Upgrading the flashinfer SHA in the wheel has some problems that I haven't had time to look into yet (otherwise we could test this empirically more easily).
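
For concreteness, a minimal sketch (hypothetical names, not SGLang's actual wrapper) of what option 1 above would look like, i.e. calling flashinfer's Python rmsnorm directly in the QK-norm path instead of the sgl-kernel op:

import torch
from flashinfer import norm as fin_norm

def qk_rmsnorm_via_flashinfer_python(q: torch.Tensor, k: torch.Tensor,
                                     q_weight: torch.Tensor, k_weight: torch.Tensor,
                                     eps: float = 1e-6):
    # q: (num_tokens, num_qo_heads, head_dim), k: (num_tokens, num_kv_heads, head_dim).
    # rmsnorm normalizes over the last dim, so both weights have length head_dim.
    return fin_norm.rmsnorm(q, q_weight, eps=eps), fin_norm.rmsnorm(k, k_weight, eps=eps)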

Early this year sgl ported some flashinfer kernels (including norm) into sgl-kernel with standalone packaging because flashinfer had moved to JIT at the time. I don't think that's still necessary at this point, and I'd prefer to use the flashinfer Python APIs directly instead of a source-level dependency. As you mentioned, we have had lots of refactors since then.

sgl has a little bit of usage of the flashinfer Python wrappers, but some things (like fused RMS norm / RMS norm) are evidently still in the sgl-kernel wheel.

vincentzed · Oct 27 '25 19:10

Upgrading the flashinfer SHA in the wheel has some problems

Yes, because we changed the API, and it will take a lot of effort to update that on the sglang side (we don't guarantee the stability of the C++ APIs).

but some things (like fused RMS norm / RMS norm) are evidently still in the sgl-kernel wheel

sglang used to rely on flashinfer's Python APIs for these functions (rmsnorm, etc.) back when we released an AOT wheel only. I don't think that's still necessary, as we now have stable release pipelines for both AOT and JIT wheels on the flashinfer side.

yzh119 · Oct 27 '25 19:10

@vincentzed Sorry for the late reply. Here are the things on my mind:

  1. This 3D implementation is not about a performance advantage from being persistent; instead, it is there to support non-contiguous input. I haven't tested on Blackwell, but running your script on H200 gives comparable numbers:
batch,num_heads,head_dim,input_ndim,is_contiguous_original,orig_shape,orig_stride,vec_align,copy_ms,mean3_ms,std3_ms,mean3_contig_ms,std3_contig_ms,mean2_ms,std2_ms,jobs
1,1,256,2,True,"(1, 256)","(256, 1)",8,,0.005981000534163133,1.962612882453122e-05,0.005981000534163133,1.962612882453122e-05,0.0059850137545446195,1.1544483438337546e-05,1
1,1,1024,2,True,"(1, 1024)","(1024, 1)",8,,0.006034079931298585,1.8040962399209137e-05,0.006034079931298585,1.8040962399209137e-05,0.006040765351815367,1.156956099460412e-05,1
1,1,4096,2,True,"(1, 4096)","(4096, 1)",8,,0.006271360305756527,4.643516107634431e-05,0.006271360305756527,4.643516107634431e-05,0.006258204057140081,9.068072653489733e-06,1
1,1,8192,2,True,"(1, 8192)","(8192, 1)",8,,0.006633921931978511,1.2879135969290683e-05,0.006633921931978511,1.2879135969290683e-05,0.006636676068551759,1.6487512540588818e-05,1
8,1,256,2,True,"(8, 256)","(256, 1)",8,,0.005904900624492044,1.5947473415889564e-05,0.005904900624492044,1.5947473415889564e-05,0.005900192811473293,1.8185705957831912e-05,8
8,1,1024,2,True,"(8, 1024)","(1024, 1)",8,,0.005969399579423343,3.048738023746512e-05,0.005969399579423343,3.048738023746512e-05,0.005935703212387159,1.3387354406690704e-05,8
8,1,4096,2,True,"(8, 4096)","(4096, 1)",8,,0.006274061337838674,1.227426043422927e-05,0.006274061337838674,1.227426043422927e-05,0.006288149234939134,1.9953791749279674e-05,8
8,1,8192,2,True,"(8, 8192)","(8192, 1)",8,,0.006693300180084355,1.9129815377125664e-05,0.006693300180084355,1.9129815377125664e-05,0.00669569863868943,8.8176767152521e-06,8
64,1,256,2,True,"(64, 256)","(256, 1)",8,,0.005978307708902481,4.624543841693019e-05,0.005978307708902481,4.624543841693019e-05,0.0060078992061522105,9.440892701249972e-06,64
  2. Non-contiguous input is common when the qkv projection is fused to squeeze out GEMM performance. In that case an additional .contiguous() call is needed to make the existing 2D kernel functional, which leads to 2x more latency (see also the sketch after this list). Here is a simple example: https://gist.github.com/happierpig/8a70d7d49243a25b0ee11049a4bd638e
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=1, default_ms=0.011ms, customized_ms=0.009ms
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=16, default_ms=0.022ms, customized_ms=0.009ms
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=64, default_ms=0.022ms, customized_ms=0.010ms
num_kv_heads=8, num_qo_heads=64, head_dim=128, batch_size=128, default_ms=0.028ms, customized_ms=0.010ms
  3. Anyhow, this modification should be transparent to the Python-side API.
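
For reference, a minimal sketch (my own, independent of the linked gist) of why QK-norm inputs end up non-contiguous with a fused qkv projection, and what the extra copy on the 2D path looks like:

import torch
from flashinfer import norm as fin_norm

num_tokens, num_qo_heads, num_kv_heads, head_dim = 128, 64, 8, 128
total_heads = num_qo_heads + 2 * num_kv_heads

# Fused qkv output viewed per head: (num_tokens, total_heads, head_dim), contiguous.
qkv = torch.randn(num_tokens, total_heads * head_dim, device="cuda", dtype=torch.float16)
qkv = qkv.view(num_tokens, total_heads, head_dim)

q = qkv[:, :num_qo_heads]                              # strided (non-contiguous) view
k = qkv[:, num_qo_heads:num_qo_heads + num_kv_heads]   # strided (non-contiguous) view
assert not q.is_contiguous() and not k.is_contiguous()

w_q = torch.randn(head_dim, device="cuda", dtype=torch.float16)
# 3D path: takes the strided view directly, no copy needed.
q_normed_3d = fin_norm.rmsnorm(q, w_q, eps=1e-6)
# 2D path: must first materialize a contiguous copy, which is the extra latency above.
q_normed_2d = fin_norm.rmsnorm(q.contiguous().view(-1, head_dim), w_q, eps=1e-6)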

happierpig · Nov 05 '25 01:11