[float8] Add support for blockwise fp8 quantization scheme used in DeepSeek v3

Open danielvegamyhre opened this issue 11 months ago • 4 comments

DeepSeek v3 uses a blockwise fp8 quantization strategy, where the scaling factor is computed independently for each block, rather than for each tensor/row/etc. The code is available here.
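To make the idea of a per-block scaling factor concrete, here is a minimal pure-Python sketch. It is not the DeepSeek or torchao implementation: the function names `blockwise_scales` and `quantize_blockwise` are hypothetical, the block shape is a free parameter, and the values are only scaled and clamped to the float8 e4m3 range (max finite value 448), without rounding to the actual fp8 grid.

```python
FP8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn


def blockwise_scales(x, block_rows, block_cols):
    """Compute one scale per (block_rows x block_cols) block of a 2D list."""
    rows, cols = len(x), len(x[0])
    scales = []
    for r0 in range(0, rows, block_rows):
        row_scales = []
        for c0 in range(0, cols, block_cols):
            # Absolute maximum over this block only (edge blocks may be smaller).
            amax = max(
                abs(x[r][c])
                for r in range(r0, min(r0 + block_rows, rows))
                for c in range(c0, min(c0 + block_cols, cols))
            )
            # Map the block's amax onto the fp8 max; avoid a zero scale.
            row_scales.append(amax / FP8_E4M3_MAX if amax > 0 else 1.0)
        scales.append(row_scales)
    return scales


def quantize_blockwise(x, scales, block_rows, block_cols):
    """Divide each element by its block's scale and clamp to the fp8 range.

    Illustrative only: a real kernel would also cast to an fp8 dtype.
    """
    return [
        [
            max(-FP8_E4M3_MAX,
                min(FP8_E4M3_MAX, v / scales[r // block_rows][c // block_cols]))
            for c, v in enumerate(row)
        ]
        for r, row in enumerate(x)
    ]
```

Dequantization is the reverse: multiply each stored value by the scale of the block it belongs to. The point of the scheme is that an outlier only affects the scale of its own block rather than of the whole tensor or row.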

It would be useful for torchao to support this as well, for users wishing to do research or development with this same quantization strategy.

cc @drisspg @vkuzo

danielvegamyhre avatar Jan 22 '25 01:01 danielvegamyhre

Just want to add some of my observations here. I played around a bit with block-wise FP8 on my consumer GPU (4070Ti SUPER, sm89). A simple Triton kernel does not perform very well: only a 1.4x speedup over BF16 (for reference, row-wise FP8 gives a ~1.9x speedup). With dynamic quantization overhead, the e2e speedup won't be very attractive. (Of course, optimizing for Hopper will be completely different.)

I also tried block-wise INT8 (which is the main idea of JetFire). A simple Triton kernel performs somewhat OK on a consumer GPU (1.9x speedup over BF16, compared to 2.9x for row-wise INT8; note that INT8 matmul is 4x faster than BF16 on consumer GPUs), but on A100 I couldn't get any speedup (speedup < 1). This is probably because block-wise INT8 requires a dtype conversion from INT32 to FP32 when scaling the MMA accumulator results, while FP8 does not.
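The conversion mentioned above can be seen in a scalar sketch of a block-wise scaled dot product. This is plain Python for illustration, not a Triton kernel; `blockwise_scaled_dot` and its arguments are hypothetical names, and one scale per K-block per operand is assumed.

```python
def blockwise_scaled_dot(a_q, b_q, a_scales, b_scales, block_k):
    """Dot product of two quantized vectors with one scale per K-block each."""
    acc = 0.0  # running accumulator (FP32 in a real kernel)
    for i, k0 in enumerate(range(0, len(a_q), block_k)):
        # Integer partial sum for this K-block, analogous to the INT32
        # result an INT8 MMA would produce.
        partial = sum(
            x * y for x, y in zip(a_q[k0:k0 + block_k], b_q[k0:k0 + block_k])
        )
        # The per-block scales are floats, so the integer partial sum must be
        # converted before scaling. This happens once per K-block.
        acc += float(partial) * a_scales[i] * b_scales[i]
    return acc
```

With row-wise scaling the scales can be applied once after the whole K loop, so the conversion happens a single time per output element; with block-wise scaling it sits inside the loop.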

For the quantization BLOCK_SIZE_K (number of elements along the K dim that share one scale value), I think only BLOCK_SIZE_K <= 128 would have a simple and performant implementation, since if BLOCK_SIZE_K is too big, we will use too much shared memory. I tried a few ways around it, such as loading tiles smaller than the quantization BLOCK_SIZE_K, but couldn't make it fast.
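A rough back-of-the-envelope estimate shows why a large BLOCK_SIZE_K is costly. The sketch below assumes a software-pipelined matmul that keeps `num_stages` copies of one A tile and one B tile in shared memory; the helper `smem_bytes`, the tile sizes, and the stage count of 3 are illustrative assumptions, and storage for scales or an epilogue is ignored.

```python
def smem_bytes(block_m, block_n, block_k, bytes_per_elem=1, num_stages=3):
    """Approximate shared memory used by the A and B input tiles."""
    a_tile = block_m * block_k * bytes_per_elem  # A tile: block_m x block_k
    b_tile = block_k * block_n * bytes_per_elem  # B tile: block_k x block_n
    return (a_tile + b_tile) * num_stages


# 1-byte (FP8 or INT8) elements with 128 x 128 output tiles
for block_k in (64, 128, 256):
    print(block_k, smem_bytes(128, 128, block_k))
```

Under these assumptions, doubling the K tile from 128 to 256 doubles the footprint from 98304 to 196608 bytes, which should be compared against the per-block shared memory limit of the target GPU.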

gau-nernst avatar Jan 22 '25 13:01 gau-nernst

We should also take a look at the new blockwise fp8 gemm added in cutlass 3.7

cc @alexsamardzic

drisspg avatar Jan 22 '25 17:01 drisspg

A PR has been created for this issue: PR #1668

the-tuning-machine avatar Feb 05 '25 15:02 the-tuning-machine

#1763 has been merged w/ fp8 blockwise quant/dequant, fp8 blockwise gemm, and a simple linear wrapper for anyone interested in experimenting with this.

danielvegamyhre avatar May 13 '25 01:05 danielvegamyhre

Hi @danielvegamyhre, I have added native support in PR#925 to quantize bf16 weights. Later I will add an optimized quantized kernel in SGLang and flashFloat. Let me know if it helps.

@yiakwy-xpu-ml-framework-team we already have kernels for this in torchao, see https://github.com/pytorch/ao/blob/main/torchao/prototype/blockwise_fp8/blockwise_quantization.py

We need an e2e float8 blockwise training experience using these kernels now :)

danielvegamyhre avatar Jul 01 '25 16:07 danielvegamyhre

That is nice!

Note that this line is wrong, because K is not always divisible by 128: https://github.com/pytorch/ao/blob/19237e306ceea78b74b67ef315089b7ee38ba55f/torchao/prototype/blockwise_fp8/blockwise_quantization.py#L204

Later I will contribute my algorithm directly to torchao, so that my people can reference this work.
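The divisibility concern can be illustrated in isolation. The linked line is not reproduced here; this sketch only assumes that the number of K blocks is being derived from K and a block size of 128, and the helpers `cdiv` and `k_block_ranges` are hypothetical.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b


def k_block_ranges(k, block_size=128):
    """Half-open [start, end) ranges covering K, with a shorter final block."""
    return [(s, min(s + block_size, k)) for s in range(0, k, block_size)]


K = 200
print(K // 128)           # floor division counts 1 block and drops the tail
print(cdiv(K, 128))       # ceiling division counts 2 blocks
print(k_block_ranges(K))  # the last block covers only 72 elements
```

A kernel that handles such shapes would also need to mask or pad the final partial block when loading; otherwise K has to be checked for divisibility up front.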

Hey @gau-nernst, by any chance could you send me the simple Triton kernel for the FP8 GEMM that you mentioned?

Manan17 avatar Jul 25 '25 21:07 Manan17