
float8 inference weight-only quant should map to a fused kernel or explain why not

Open • vkuzo opened this issue 2 months ago • 1 comment

Today, Float8WeightOnlyConfig maps to a reference implementation of weight-only quant, which dequantizes the tensor and then runs a high-precision gemm: https://github.com/pytorch/ao/blob/ba3ac9f2f6117ba35ff28fbb8811f61ad992dfcf/torchao/quantization/quantize_/workflows/float8/float8_tensor.py#L392

Users have reported confusion about this; we should either clearly explain that no speedup is expected or map this config to a fast kernel.
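For illustration, here is a minimal sketch of what the reference weight-only path effectively does (my own toy code, not the actual torchao internals; the function and variable names are made up): dequantize the fp8 weight back to the activation dtype and run a regular high-precision gemm, so no fp8 or fused kernel is involved.

```python
import torch
import torch.nn.functional as F

def reference_float8_weight_only_linear(x, w_fp8, w_scale, bias=None):
    # Dequantize the stored fp8 weight to the activation dtype...
    w_hp = w_fp8.to(x.dtype) * w_scale
    # ...then run a plain high-precision gemm. No fp8/fused kernel is used,
    # so no speedup is expected, and the dequantized copy costs extra memory.
    return F.linear(x, w_hp, bias)

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
w_scale = w.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)
y = reference_float8_weight_only_linear(x, w_fp8, w_scale)
```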

vkuzo avatar Nov 04 '25 12:11 vkuzo

Hi, I ran torchao's llama benchmark script with Float8WeightOnlyConfig on an RTX 5090. In addition to the initially reported slowdown, this config also shows abnormally high peak memory usage:


| run | quant | tok/s | tok/s_decode | ttft (s) | mem/s (GB/s) | peak_mem (GB) | model_size (GB) |
|---|---|---|---|---|---|---|---|
| 20251108230606 | None (bf16) | 104.68 | 105.73 | 0.0186 | 1571.23 | 16.30 | 15.01 |
| 20251108230802 | float8dq-tensor | 160.55 | 169.80 | 0.0675 | 1204.97 | 9.21 | 7.51 |
| 20251108231236 | float8wo | 7.44 | 7.48 | 0.1442 | 55.91 | 26.00 | 7.51 |

All runs: mod: Meta-Llama-3.1-8B, sparse: None, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda.

Repro (add `--quantization float8dq-tensor` or `--quantization float8wo` for the quantized runs): `python generate.py --checkpoint_path /root/checkpoints/unsloth/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8`

I'm trying to raise a PR now that I've been able to reproduce the behavior with torchao's benchmarking code instead of my own code.
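For reference, this is roughly how I understand the config is applied outside of generate.py (a sketch assuming `quantize_` and `Float8WeightOnlyConfig` are importable from `torchao.quantization` as in recent releases; the toy model is made up):

```python
import torch
from torchao.quantization import quantize_, Float8WeightOnlyConfig

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 11008, bias=False)
).to(device="cuda", dtype=torch.bfloat16)

# Swap Linear weights for float8 weight-only quantized tensors; with the
# current reference path this shrinks the stored weights but does not route
# to a fused fp8 kernel.
quantize_(model, Float8WeightOnlyConfig())
model = torch.compile(model)
```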

vipulSharma18 avatar Nov 08 '25 23:11 vipulSharma18

This is what we were doing before: https://github.com/pytorch/ao/blob/ba3ac9f2f6117ba35ff28fbb8811f61ad992dfcf/torchao/dtypes/floatx/float8_layout.py#L391

cc @drisspg @jainapurva: can you comment on why float8 weight-only quant is falling back to the reference path? Is this used anywhere?

jerryzh168 avatar Dec 05 '25 01:12 jerryzh168

AFAIK there is logic in Inductor that can do prologue fusion and actually get speedups. For decode shapes with small activations and large weights this can be faster when compiled.

That being said, I have no script or benchmark showing this; it's just my recollection of @eellison's work.
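A rough sketch of the idea (toy code on my end, not a benchmark or the torchao path): express the dequant inline inside a compiled function so Inductor has a chance to fuse it into the matmul prologue for decode-like shapes, instead of materializing a bf16 copy of the weight. Whether the fusion actually fires depends on the PyTorch version and compile settings.

```python
import torch

@torch.compile(mode="max-autotune")
def wo_linear(x, w_fp8, w_scale):
    # Dequant written inline so Inductor can try to fold it into the gemm prologue.
    return x @ (w_fp8.to(torch.bfloat16) * w_scale).t()

# Decode-like shape: tiny activation, large weight.
x = torch.randn(1, 4096, dtype=torch.bfloat16, device="cuda")
w = torch.randn(11008, 4096, dtype=torch.bfloat16, device="cuda")
w_scale = w.abs().amax() / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)
y = wo_linear(x, w_fp8, w_scale)
```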

drisspg avatar Dec 05 '25 04:12 drisspg

@drisspg do you mean prologue fusion?

eellison avatar Dec 05 '25 14:12 eellison

Yeah, fat fingered

drisspg avatar Dec 05 '25 15:12 drisspg

Hi, if it helps, I did some benchmarking for root-cause analysis. Although it's currently incomplete, it might give a better picture.

The slowdown exists for torchao 0.13 but not for 0.14, even with compilation and in memory-bound inference with small batch sizes and activations.

https://github.com/vipulSharma18/Inference-Profiling-and-Optimization-Worklog/blob/main/torchao_float8%2FREADME.md

vipulSharma18 avatar Dec 05 '25 15:12 vipulSharma18