[Feature][Hardware][AMD] Enable Scaled FP8 GEMM on ROCm
Enable Scaled FP8 GEMM on ROCm (AMD GPU)
As part of a series of FP8 developments in vLLM, this pull request adds FP8-accelerated computation on newer AMD hardware (MI30x and later).
- Uses the OCP FP8 inference data type (float8_e4m3fn) at the interface and file-exchange level, so it is compatible with OCP FP8 quantized model checkpoints.
- In addition to PTQ weights, static scaling factors are used for activations (and KV caches). These are obtained through a calibration process by Quark (the AMD quantizer) or by AMMO from NVIDIA.
- This is a ROCm/hipBLASLt-based implementation of scaled FP8 GEMM, complementing the previously implemented scaled FP8 KV cache. When multiple weight matrices are concatenated into a bigger GEMM for performance, this implementation suboptimally uses one scaling factor rather than one per matrix (to be addressed later). Note - AMD Quark can be configured so that certain matrices are virtually merged prior to their quantization.
- Largely follows the vLLM FP8 RFC: FP8 in vLLM #2461. Specifically, linear and projection layers are covered, while FP8 computation within self-attention itself is left for future extension. The current GEMM takes FP8 inputs and outputs float16/bfloat16 by default. Further optimizations are in progress, including but not limited to float16/bfloat16 ingress (in-kernel conversion), direct FP8 egress to the KV cache, etc.
- Note - this feature will not work on MI2xx or older GPUs lacking FP8 MFMA instructions.
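The scheme described above (per-tensor static scales, FP8 operands, higher-precision output) can be illustrated with a minimal pure-Python sketch. It is not the hipBLASLt kernel or a vLLM API: `to_e4m3fn`, `amax_scale`, `quantize`, `scaled_fp8_linear`, and `requantize_to_shared_scale` are hypothetical helpers. It assumes round-to-nearest-even with saturation at ±448 for out-of-range values, and uses Python floats as a stand-in for a higher-precision accumulator. For the concatenated-weights case, falling back to the maximum of the per-shard scales is one plausible way to share a single scale; it is an assumption here, not necessarily what this PR does.

```python
import math

FP8_E4M3FN_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def to_e4m3fn(x: float) -> float:
    """Round x to the nearest float8_e4m3fn value, saturating at +/-448 (assumed)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), FP8_E4M3FN_MAX)
    _, e = math.frexp(a)      # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)      # -6 is the smallest normal exponent (subnormals below it)
    step = 2.0 ** (exp - 3)   # 3 mantissa bits -> spacing of representable values
    q = min(round(a / step) * step, FP8_E4M3FN_MAX)  # round() is round-half-to-even
    return math.copysign(q, x)


def amax_scale(tensor):
    """Per-tensor scale that maps the observed |max| onto the FP8 range."""
    amax = max(abs(v) for row in tensor for v in row)
    return max(amax, 1e-12) / FP8_E4M3FN_MAX


def quantize(tensor, scale):
    """Divide every element by the scale and round it to FP8."""
    return [[to_e4m3fn(v / scale) for v in row] for row in tensor]


def scaled_fp8_linear(x_q, scale_x, w_q, scale_w):
    """y = (x_q @ w_q^T) * scale_x * scale_w, with w_q laid out as (out, in)."""
    return [
        [sum(xv * wv for xv, wv in zip(x_row, w_row)) * scale_x * scale_w for w_row in w_q]
        for x_row in x_q
    ]


def requantize_to_shared_scale(shards_q, scales):
    """Re-express FP8 weight shards that each have their own scale with one shared scale."""
    shared = max(scales)  # the largest scale guarantees no shard overflows
    shards = [
        [[to_e4m3fn(v * s / shared) for v in row] for row in w_q]
        for w_q, s in zip(shards_q, scales)
    ]
    return shards, shared


x = [[1.0, -2.0, 0.5], [0.25, 3.0, -1.0]]  # activations (batch, in)
w = [[0.1, 0.2, -0.3], [1.5, -0.5, 0.05]]  # weights (out, in)
# In the real flow the activation scale comes from calibration data, not the batch itself.
sx, sw = amax_scale(x), amax_scale(w)
y = scaled_fp8_linear(quantize(x, sx), sx, quantize(w, sw), sw)
# y approximates the full-precision result [[-0.45, 2.525], [0.925, -1.175]]
```

With a shared scale, a shard whose own scale was smaller is pushed toward the bottom of the FP8 range and loses precision, which is why one scale for a concatenated GEMM is described as suboptimal.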
Design Reference:
- Note - Quark may add AutoFP8-compatible export; when it does, we will extend support accordingly.
- RFC: FP8 Quantization Schema in vLLM update #5802
- RFC: FP8 Quantization Schema in vLLM #3218
- RFC: FP8 in vLLM #2461
Introducing Quark - AMD Quantizer:
Please refer to: AMD Quark landing page
Performance Tuning:
Please refer to: AMD vLLM performance tuning guide
Usage and Examples:
To get started, please refer to:
./examples/fp8/quantizer/README.md
Performance and Accuracy:
Combined with the FP8 KV cache, we observed up to ~50% performance improvement over the FP16 Llama2 baseline, with larger gains at larger batch sizes and longer sequences, even for the quantized 70B model served on a single MI300X.
LLM-Q&A, Llama2-70b, dataset: OPEN ORCA on 8 MI300X GPUs (TP=8)
| GEMM Type | Rouge-1 | Rouge-2 | Rouge-L |
|---|---|---|---|
| FP16 | 44.4860 | N/A | 28.6992 |
| FP8 scaled | 44.5001 | 22.0853 | 28.7140 |
Hi @HaiShaw thanks for pushing up this chunk of work. Is there a reason you haven't tried enabling AMD explicitly through the existing "fp8" quantization backend with the current checkpoint format? It seems within your "Fp8Fnuz" method that torch._scaled_mm is actually a valid else case, so could you take advantage of its usage already in the "fp8" backend for an easier starting point?
@mgoin thanks for your question! There were a couple of reasons we did not reuse the same backend as-is. Besides the different internal (HW) format and GEMM implementations, the main reason is that we do not consider dynamic scaling (and we would rather not mix too much into the CUDA backend code). In terms of model loading, we started with AMMO support, now support AMD Quark, and will extend to AutoFP8-compatible checkpoints once RFC #5802 is landed in Quark. Some of the discrepancies come from the evolving nature and varying completeness of the several quantizers we deal with here, which follow different design ideas.
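To make the static-vs-dynamic distinction concrete: with static scaling, the activation scale is calibrated once offline (as with Quark or AMMO) and reused, so any runtime value beyond the calibrated range saturates; with dynamic scaling, the scale is recomputed for every batch at the cost of an extra amax reduction. The sketch below is a hypothetical pure-Python illustration (`amax_scale` and `saturated_fraction` are not vLLM functions) and assumes saturating conversion at ±448.

```python
FP8_E4M3FN_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def amax_scale(tensor):
    """Per-tensor scale that maps |max| of `tensor` onto the FP8 range."""
    amax = max(abs(v) for row in tensor for v in row)
    return max(amax, 1e-12) / FP8_E4M3FN_MAX


def saturated_fraction(tensor, scale):
    """Fraction of elements that would be clamped at +/-448 with this scale."""
    vals = [v for row in tensor for v in row]
    return sum(abs(v / scale) > FP8_E4M3FN_MAX for v in vals) / len(vals)


# Static: the scale is fixed offline from calibration data and stored with the checkpoint.
calibration_batch = [[0.5, -3.5], [2.0, 1.0]]
static_scale = amax_scale(calibration_batch)  # 3.5 / 448 = 1/128

# At serving time a batch may contain a value outside the calibrated range.
runtime_batch = [[1.0, 7.0], [-2.0, 3.0]]
print(saturated_fraction(runtime_batch, static_scale))  # 0.25 (the 7.0 entry saturates)

# Dynamic: recompute the scale from each batch (an extra amax pass per call).
dynamic_scale = amax_scale(runtime_batch)  # 7.0 / 448 = 1/64
print(saturated_fraction(runtime_batch, dynamic_scale))  # 0.0
```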
I'm going to review this weekend.
@robertgshaw2-neuralmagic thanks for your time. Some of these issues are known, and we will provide an update with AutoFP8-compatible support soon.
Going to summarize comments into a single PR
@robertgshaw2-neuralmagic Is there any update?
@robertgshaw2-neuralmagic, thanks for your review; we will get back to you and address your concerns.
Closing as the source branch has been deleted with no further activity for 2 months