
Add fp8 gemm kernel for CPU in sgl-kernel and add gemm UT

Open chunyuan-w opened this issue 7 months ago • 4 comments

Motivation

This PR is a follow-up to https://github.com/sgl-project/sglang/issues/2807 and https://github.com/sgl-project/sglang/pull/5150, adding an fp8 gemm kernel for CPU. The bf16 and int8 gemm kernels were already added in https://github.com/sgl-project/sglang/pull/5150.

This PR also adds UTs for the bf16, int8, and fp8 gemm kernels on CPU.

Modifications

The main change is the C++ kernel for fp8 gemm on CPU: sgl-kernel/csrc/cpu/gemm_fp8.cpp
The UTs for the gemm ops on CPU: test/srt/cpu/test_gemm.py
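A UT for a gemm kernel typically compares the kernel output against a full-precision reference under a tolerance. Below is a minimal pure-Python illustration of that pattern; the names and tolerances are hypothetical, not the actual code in test/srt/cpu/test_gemm.py (which calls the C++ kernel directly):

```python
# Hypothetical sketch of a gemm reference check, not the real UT.
# The real test calls the C++ CPU kernel; here we only illustrate the
# reference-vs-result comparison pattern.

def ref_gemm(a, b_t):
    # a: M x K activations, b_t: N x K weight (stored transposed);
    # returns the M x N product computed in full precision.
    return [[sum(x * y for x, y in zip(row, col)) for col in b_t]
            for row in a]

def allclose(out, ref, atol=1e-2):
    # Elementwise absolute-tolerance check (same idea as
    # torch.testing.assert_close, but simplified).
    return all(abs(o - r) <= atol
               for orow, rrow in zip(out, ref)
               for o, r in zip(orow, rrow))

if __name__ == "__main__":
    a = [[1.0, 2.0], [3.0, 4.0]]
    w_t = [[5.0, 6.0]]            # one output channel, K = 2
    out = ref_gemm(a, w_t)        # [[17.0], [39.0]]
    assert allclose(out, [[17.0], [39.0]])
```

In the real test, `out` would come from the sgl-kernel CPU op and `ref` from a plain torch matmul in higher precision.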

chunyuan-w avatar May 12 '25 07:05 chunyuan-w

@mingfeima could you please review this PR?

chunyuan-w avatar May 12 '25 08:05 chunyuan-w

@chunyuan-w LGTM! Let's wait for @blzheng to finish the CMakeLists.txt change and rebase after it.

mingfeima avatar May 12 '25 12:05 mingfeima

@chunyuan-w need to fix the CI failures if they are real.

mingfeima avatar May 13 '25 01:05 mingfeima

@chunyuan-w need to fix the CI failures if they are real.

I merged the latest main branch and now the CIs are all green.

chunyuan-w avatar May 13 '25 02:05 chunyuan-w

@chunyuan-w please rebase as https://github.com/sgl-project/sglang/pull/6115 has been landed.

mingfeima avatar May 14 '25 01:05 mingfeima

@chunyuan-w please rebase as #6115 has been landed.

I have merged the latest main. There are new CI failures, but they seem unrelated to the changes in this PR. I can rebase later to trigger the CI again and see if they still fail.

This one is happening in other PRs as well.

As for this one: since we haven't yet updated the model code to call the kernel added in this PR, this PR shouldn't cause any model accuracy change.

chunyuan-w avatar May 14 '25 05:05 chunyuan-w

@chunyuan-w please rebase as #6115 has been landed.

I have merged the latest main. There are new CI failures, but they seem unrelated to the changes in this PR. I can rebase later to trigger the CI again and see if they still fail.

This one is happening in other PRs as well.

As for this one: since we haven't yet updated the model code to call the kernel added in this PR, this PR shouldn't cause any model accuracy change.

Don't worry, let's re-run the jobs.

Alcanderian avatar May 14 '25 05:05 Alcanderian

Hi @Alcanderian, it seems there are still CI failures unrelated to the changes in this PR. Do we need to re-run them?

chunyuan-w avatar May 15 '25 01:05 chunyuan-w

@mingfeima @chunyuan-w, hello, one question about this fp8 gemm on AMX CPUs: will it use AMX_INT8 to compute fp8?

blossomin avatar Jul 09 '25 09:07 blossomin

@mingfeima @chunyuan-w, hello, one question about this fp8 gemm on AMX CPUs: will it use AMX_INT8 to compute fp8?

The activation is BF16. We dequantize the weight from FP8 to BF16 and do the computation inside the kernel.
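For reference, the fp8-to-float conversion for the common e4m3 format (OCP variant: 1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits) can be sketched in pure Python. This is only an illustration of the dequantization concept, not the actual kernel code:

```python
def fp8_e4m3_to_float(byte):
    # Decode one OCP fp8 e4m3 byte: 1 sign bit, 4 exponent bits (bias 7),
    # 3 mantissa bits. exp=0 encodes subnormals; exp=15, mant=7 encodes NaN.
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return float("nan")
    if exp == 0:                                  # subnormal range
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)

# The kernel would then apply the quantization scale and run the matmul
# in bf16, conceptually: w_bf16 = fp8_e4m3_to_float(q) * scale
```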

chunyuan-w avatar Jul 11 '25 06:07 chunyuan-w

@mingfeima @chunyuan-w, hello, one question about this fp8 gemm on AMX CPUs: will it use AMX_INT8 to compute fp8?

The activation is BF16. We dequantize the weight from FP8 to BF16 and do the computation inside the kernel.

Do you have any plan to support AMX_INT8 computation for this fp8 version, since it can halve the data size? (Correct me if the data is moved to registers before dequantization.) This might help decoding speed, given the limited host memory bandwidth.

Huawei CloudMatrix also uses INT8 to accelerate inference; see Fig. 10 in this report: https://arxiv.org/pdf/2506.12708v1

blossomin avatar Jul 11 '25 07:07 blossomin

@blossomin Ascend also does not support fp8; they re-quantize the model to int8. On the CPU path we also support int8 with a w8a8 per-channel recipe, so it is the same story as Ascend.

But some customers have clearly pointed out that they need fp8 weights as well (mostly to compare accuracy), so we implemented an emulated approach here that converts fp8 to bf16 for computation, with minimal accuracy loss.

The next generation of Xeon (7th gen, codename DMR, targeted for 2026) will support AMX-FP8 natively.

mingfeima avatar Jul 11 '25 08:07 mingfeima