[float8] Add support for blockwise fp8 quantization scheme used in DeepSeek v3
DeepSeek v3 uses a blockwise fp8 quantization strategy, where the scaling factor is computed independently for each block, rather than for each tensor/row/etc. The code is available here.
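For concreteness, here is a minimal PyTorch sketch of the per-block scaling idea described above, assuming 128x128 weight blocks (the granularity the DeepSeek v3 report uses for weights), `float8_e4m3fn`, and dimensions divisible by the block size. Function names are illustrative; this is not the DeepSeek or torchao implementation.

```python
import torch

def quantize_fp8_blockwise(w: torch.Tensor, block_size: int = 128):
    """Quantize a 2D weight to fp8 with one scale per (block_size x block_size) block.

    Assumes w's dimensions are divisible by block_size; see the discussion of
    non-divisible K further down the thread.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    M, K = w.shape
    # View as (M // bs, bs, K // bs, bs) so each (bs, bs) block can be reduced over.
    blocks = w.reshape(M // block_size, block_size, K // block_size, block_size)
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).float().clamp(min=1e-12)
    scale = amax / fp8_max  # one scale per block
    w_fp8 = (blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return w_fp8.reshape(M, K), scale.reshape(M // block_size, K // block_size)

def dequantize_fp8_blockwise(w_fp8: torch.Tensor, scale: torch.Tensor, block_size: int = 128):
    """Reference dequant: upcast and multiply each block by its scale."""
    M, K = w_fp8.shape
    blocks = w_fp8.to(torch.float32).reshape(M // block_size, block_size, K // block_size, block_size)
    return (blocks * scale[:, None, :, None]).reshape(M, K)
```

Activations in the DeepSeek v3 recipe are scaled at a finer 1x128 granularity, but the same amax-per-group pattern applies.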
It would be useful for torchao to support this as well, for users wishing to do research or development with this same quantization strategy.
cc @drisspg @vkuzo
Just want to add some of my observations here. I played around a bit with block-wise FP8 on my consumer GPU (4070Ti SUPER, sm89). A simple Triton kernel does not perform very well: only a 1.4x speedup over BF16 (for reference, row-wise FP8 is ~1.9x). With dynamic quantization overhead, the end-to-end speedup won't be too attractive. (Of course, optimizing for Hopper will be completely different.)
I also tried block-wise INT8 (which is the main idea of JetFire). A simple Triton kernel performs somewhat OK on a consumer GPU (1.9x speedup over BF16, compared to 2.9x for row-wise INT8 - note that INT8 matmul is 4x faster than BF16 on consumer GPUs), but on A100 I couldn't get any speedup (speedup < 1). That is probably because block-wise INT8 requires a dtype conversion from INT32 to FP32 when scaling the MMA accumulator results, while FP8 does not.
For the quantization BLOCK_SIZE_K (the number of elements along the K dim that share one scale value), I think only BLOCK_SIZE_K <= 128 would have a simple and performant implementation, since if BLOCK_SIZE_K is too big, we will use too much shared memory. I tried a few ways around it, such as loading tiles smaller than the quantization BLOCK_SIZE_K, but couldn't make it fast.
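For anyone who wants a starting point to reproduce these kinds of measurements, below is a minimal illustrative Triton sketch of a blockwise-scaled FP8 GEMM. It is not the kernel benchmarked above and not the torchao kernel; it assumes the quantization group size along K equals BLOCK_K, a (row, K-block) scale layout for the activation and a (K-block, N-block) layout for the weight, row-major contiguous inputs, shapes divisible by the tile sizes, bf16 output, and an FP8-capable GPU (sm89+). All names and tile sizes are illustrative.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def fp8_blockwise_gemm_kernel(
    a_ptr, a_scale_ptr,   # A: (M, K) fp8;  a_scale: (M, K // BLOCK_K) fp32
    b_ptr, b_scale_ptr,   # B: (K, N) fp8;  b_scale: (K // BLOCK_K, N // BLOCK_N) fp32
    c_ptr,                # C: (M, N) bf16 output
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    num_k_blocks = K // BLOCK_K
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for kb in range(num_k_blocks):
        k0 = kb * BLOCK_K
        # Load one fp8 tile of A and B for this K-block (assumes divisible shapes).
        a = tl.load(a_ptr + offs_m[:, None] * K + (k0 + offs_k)[None, :])
        b = tl.load(b_ptr + (k0 + offs_k)[:, None] * N + offs_n[None, :])
        # One scale per (row, K-block) for A and per (K-block, N-block) for B.
        a_s = tl.load(a_scale_ptr + offs_m * num_k_blocks + kb)
        b_s = tl.load(b_scale_ptr + kb * (N // BLOCK_N) + pid_n)
        # Rescale this K-block's partial product in fp32 before accumulating.
        acc += tl.dot(a, b) * a_s[:, None] * b_s
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc.to(tl.bfloat16))


def fp8_blockwise_gemm(a_fp8, a_scale, b_fp8, b_scale, block_m=64, block_n=128, block_k=128):
    M, K = a_fp8.shape
    _, N = b_fp8.shape
    c = torch.empty((M, N), device=a_fp8.device, dtype=torch.bfloat16)
    grid = (M // block_m, N // block_n)
    fp8_blockwise_gemm_kernel[grid](
        a_fp8, a_scale, b_fp8, b_scale, c, M, N, K,
        BLOCK_M=block_m, BLOCK_N=block_n, BLOCK_K=block_k,
    )
    return c
```

The point that matches the observation above is that each K-tile's partial product is rescaled in fp32 before accumulation; for FP8 the tensor-core accumulator is already fp32, whereas the INT8 variant pays for an extra INT32-to-FP32 conversion at that step.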
We should also take a look at the new blockwise fp8 gemm added in cutlass 3.7
cc @alexsamardzic
A PR has been created for this issue: PR #1668
#1763 has been merged with fp8 blockwise quant/dequant, fp8 blockwise gemm, and a simple linear wrapper, for anyone interested in experimenting with this.
Hi @danielvegamyhre, I have added native support in PR#925 to quantize bf16 weights. Later I will add an optimized quantized kernel in SGLang and flashFloat; let me know if it helps.
@yiakwy-xpu-ml-framework-team we already have kernels for this in torchao, see https://github.com/pytorch/ao/blob/main/torchao/prototype/blockwise_fp8/blockwise_quantization.py
We need an e2e float8 blockwise training experience using these kernels now :)
That is nice!
Note that this line is wrong, because K is not always divisible by 128: https://github.com/pytorch/ao/blob/19237e306ceea78b74b67ef315089b7ee38ba55f/torchao/prototype/blockwise_fp8/blockwise_quantization.py#L204 (see the padding sketch below for one way to handle this).
Later I will update my algorithm directly in torchao, so that people can have a reference to this work.
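For illustration, one common workaround when K is not divisible by the group size is to zero-pad up to the next multiple before quantizing. A minimal PyTorch sketch (hypothetical helper, not the torchao code) follows:

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(w: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Zero-pad the last dim of w up to a multiple of block_size.

    Zero padding cannot increase a block's absmax, so the original values
    quantize the same way, and the padded columns can be sliced off again
    after dequantization.
    """
    K = w.shape[-1]
    pad = (-K) % block_size  # elements needed to reach the next multiple
    if pad == 0:
        return w
    return F.pad(w, (0, pad))
```

Alternatively, the last group can simply be allowed to be smaller than block_size, at the cost of some extra indexing in the kernel.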
Hey, by any chance can you send me the simple Triton kernel for the FP8 GEMM that you mentioned?