float8 inference weight-only quant should map to a fused kernel or explain why not
Today, Float8WeightOnlyConfig maps to a reference implementation of weight-only quant, which dequantizes the tensor and then runs a high-precision gemm: https://github.com/pytorch/ao/blob/ba3ac9f2f6117ba35ff28fbb8811f61ad992dfcf/torchao/quantization/quantize_/workflows/float8/float8_tensor.py#L392
Users have reported confusion about this; we should either clearly explain that no speedup is expected or map to a fast kernel.
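For context, the reference weight-only path is roughly equivalent to the sketch below (illustrative only, not the actual torchao code; names and scale handling are simplified):

```python
import torch
import torch.nn.functional as F

def float8_weight_only_linear_reference(x, w_fp8, w_scale, bias=None):
    # Dequantize the fp8 weight back to the activation dtype (e.g. bf16)...
    w_dequant = w_fp8.to(x.dtype) * w_scale
    # ...then run an ordinary high-precision gemm: no fp8 tensor cores are used,
    # and the full-precision weight is materialized before the matmul.
    return F.linear(x, w_dequant, bias)
```

So unless the dequant gets fused into the gemm, no speedup should be expected over plain bf16, and peak memory can even regress because the dequantized weights are materialized.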
Hi, I ran torchao's llama benchmark script with Float8WeightOnlyConfig on an RTX 5090, and there's also abnormally high peak memory usage with this config, on top of the initially reported slowdown:
20251108230606, tok/s=104.68, tok/s_decode=105.73, ttft=0.0186, mem/s=1571.23 GB/s, peak_mem=16.30 GB, model_size=15.01 GB quant: None, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path /root/checkpoints/unsloth/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20251108230802, tok/s=160.55, tok/s_decode=169.80, ttft=0.0675, mem/s=1204.97 GB/s, peak_mem= 9.21 GB, model_size= 7.51 GB quant: float8dq-tensor, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8dq-tensor --checkpoint_path /root/checkpoints/unsloth/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
20251108231236, tok/s= 7.44, tok/s_decode= 7.48, ttft=0.1442, mem/s= 55.91 GB/s, peak_mem=26.00 GB, model_size= 7.51 GB quant: float8wo, sparse: None, mod: Meta-Llama-3.1-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization float8wo --checkpoint_path /root/checkpoints/unsloth/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
I'm trying to raise a PR now that I've been able to reproduce the behavior with torchao's benchmarking code instead of my own code.
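For anyone reproducing this outside generate.py, the two quantized rows above correspond roughly to the following torchao configs (a sketch; exact import paths and config names may vary between torchao versions):

```python
import copy
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8WeightOnlyConfig,
    PerTensor,
    quantize_,
)

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).to(torch.bfloat16).cuda()

# "float8wo": fp8 weight-only quantization -- the config showing the slowdown above.
m_wo = copy.deepcopy(model)
quantize_(m_wo, Float8WeightOnlyConfig())

# "float8dq-tensor": dynamic fp8 activations + fp8 weights with per-tensor scales,
# which maps to a fused fp8 gemm and shows the expected speedup.
m_dq = copy.deepcopy(model)
quantize_(m_dq, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
```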
This is what we were doing before: https://github.com/pytorch/ao/blob/ba3ac9f2f6117ba35ff28fbb8811f61ad992dfcf/torchao/dtypes/floatx/float8_layout.py#L391
cc @drisspg @jainapurva can you comment on why float8 weight-only quant is falling back to the reference path? Is this config actually used?
AFAIK there is a pass in Inductor that can do prologue fusion and actually get speedups. For decode shapes, with small activations and large weights, this can be faster if compiled.
That being said, I have no script or benchmark showing this; it's just my recollection of @eellison's work.
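To make the shape regime concrete, here's a minimal sketch (my own illustration, not torchao code) of a decode-shaped weight-only matmul run under torch.compile; whether Inductor actually fuses the dequant prologue into the gemm depends on the PyTorch version and compile settings:

```python
import torch

def wo_linear(x, w_fp8, w_scale):
    # The dequant is written inline; whether Inductor fuses it into the gemm
    # (prologue fusion) depends on the PyTorch version and compile mode.
    return x @ (w_fp8.to(x.dtype) * w_scale).t()

compiled = torch.compile(wo_linear, mode="max-autotune")

x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)            # small decode activation
w_fp8 = torch.randn(11008, 4096, device="cuda").to(torch.float8_e4m3fn)  # large fp8 weight
w_scale = torch.tensor(1.0, device="cuda", dtype=torch.bfloat16)
out = compiled(x, w_fp8, w_scale)
```

With a (1, 4096) activation against a roughly 11k x 4k weight, the op is weight-bandwidth bound, which is exactly where fusing the dequant into the gemm would matter most.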
@drisspg do you mean prologue fusion?
Yeah, fat fingered
Hi, if it helps, I did some benchmarking for root-cause analysis. It's still incomplete, but it might give a better picture.
The slowdown exists for torchao 0.13 but not for 0.14; it occurs even with compilation and in memory-bound inference with small batch sizes and activations.
https://github.com/vipulSharma18/Inference-Profiling-and-Optimization-Worklog/blob/main/torchao_float8%2FREADME.md