Enable scaled FP8 (e4m3fn) KV cache on ROCm (AMD GPU)
As part of a series of FP8 developments in vLLM, this pull request addresses an OCP-format (NVIDIA-compatible) FP8 KV cache. We elaborated upon the previous #2279, but made the following changes, enhancements, and extensions:
- Using the OCP FP8 data type, E4M3, recommended for inference (Float8E4M3FN in MLIR, float8_e4m3fn in PyTorch)
- Using a scaled FP8 KV cache to mitigate quantization loss (scaling factors are acquired from the AMD Quantizer, AMMO, etc.)
- Enabled on AMD MI3xx GPUs, MI300X (192 GB HBM) in particular (less performant on older silicon without FP8 hardware)
Design reference:
Scope:
- Used in conjunction with the quantizer's output: KV cache scaling factors. For this phase, the `activation` and `weights` sections from the JSON schema proposed in #2461 are not included. The quantizer's output may need to be formatted into a JSON file based on that schema for the vLLM code to consume; a utility script (`3rdparty/quantizer/extract_scales.py`) is provided for JSON generation from AMMO's output (an illustrative sketch of the idea follows this list).
- Quantizers supported: AMD Quantizer, NVIDIA AMMO. A utility script (`3rdparty/quantizer/quantize.py`) is provided for using AMMO to quantize an HF model to FP8 with an FP8 KV cache, so that KV cache scaling factors are generated (over a calibration dataset, which you can change to your domain of interest); details in `3rdparty/README.md`.
- Only `e4m3fn`, the common OCP format used for FP8 inference and model forward/eval, is enabled; it comes with hardware support (and is therefore performant) on AMD MI3xx GPUs. The same design is still functional, but less performant, on earlier AMD GPUs; the current design does not cover CUDA devices.
- Model: Llama first; others will be added later after approval.
- FP8 KV cache only, with scaling; FP8 compute is coming next.
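For orientation, here is a minimal sketch of turning per-layer absmax observations into a KV cache scaling-factor JSON file. The field names below are hypothetical and purely illustrative; the authoritative schema is the one proposed in #2461, and `3rdparty/quantizer/extract_scales.py` is what actually generates the file from AMMO output.

```python
# Hypothetical sketch only: field names are illustrative, not the actual schema from #2461.
import json

OCP_E4M3_MAXNORM = 448.0

def write_kv_cache_scales(per_layer_absmax: dict, path: str, tp_size: int = 1) -> None:
    """per_layer_absmax maps layer index -> observed absmax of that layer's K/V activations."""
    scales = {layer: absmax / OCP_E4M3_MAXNORM for layer, absmax in per_layer_absmax.items()}
    doc = {
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            # one entry per tensor-parallel rank, each mapping layer index -> scaling factor
            "scaling_factor": {
                str(rank): {str(layer): scale for layer, scale in scales.items()}
                for rank in range(tp_size)
            },
        }
    }
    with open(path, "w") as f:
        json.dump(doc, f, indent=2)

# Example: two layers whose calibrated absmax values were observed during quantization.
write_kv_cache_scales({0: 10.4, 1: 8.7}, "kv_cache_scales.json")
```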
Scaling semantics:
- In concept and in this design, we use the following definitions (a minimal PyTorch sketch follows this list):
- `scaling_factor = AbsMax(input_tensor_fp16_or_bfloat16_or_fp32) / OCP_E4M3_MAXNORM`, where `OCP_E4M3_MAXNORM = 448.0`. This semantics is used by the AMD Quantizer, and by AMMO upon observation.
- `scaled_to_fp8_quant: fp8_tensor = fp8_quant(higher_precision_tensor / scaling_factor)`
- `scaled_fr_fp8_dequant: higher_precision_tensor = fp8_dequant(fp8_tensor) * scaling_factor`
- Per-tensor scaling for now; per-channel, etc. scaling in the future.
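A minimal PyTorch sketch of these semantics, assuming a build with `torch.float8_e4m3fn`; the real conversion lives in the C++/HIP kernels of this PR, and the function names below simply mirror the definitions above:

```python
import torch

OCP_E4M3_MAXNORM = 448.0  # largest finite magnitude representable in OCP FP8 E4M3 (e4m3fn)

def scaled_to_fp8_quant(t: torch.Tensor):
    # per-tensor scaling factor: AbsMax(tensor) / 448.0
    scaling_factor = t.abs().max().float() / OCP_E4M3_MAXNORM
    fp8_tensor = (t / scaling_factor).clamp(-OCP_E4M3_MAXNORM, OCP_E4M3_MAXNORM).to(torch.float8_e4m3fn)
    return fp8_tensor, scaling_factor

def scaled_fr_fp8_dequant(fp8_tensor: torch.Tensor, scaling_factor: torch.Tensor, dtype=torch.float16):
    return fp8_tensor.to(dtype) * scaling_factor

x = torch.randn(16, 128, dtype=torch.float16) * 4.0
x_fp8, scale = scaled_to_fp8_quant(x)
x_rec = scaled_fr_fp8_dequant(x_fp8, scale)
print("max abs round-trip error:", (x.float() - x_rec.float()).abs().max().item())
```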
Usage:
To start, please refer to:
./tests/fp8_kv/README.md
./3rdparty/README.md
Two example JSON files are provided under:
./tests/fp8_kv/
If you run vLLM with kv_cache_dtype="fp8" but do not provide a JSON file containing scaling factors, then no scaling will be applied to the FP8 (e4m3fn) quantization (the scaling factor defaults to 1.0), which may lead to less accurate results.
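The cost of a missing scale is easiest to see on values that fall outside the E4M3 representable range of ±448: without a calibrated scaling factor they are clipped. A toy round-trip comparison, again assuming `torch.float8_e4m3fn` support and written for illustration only (this is not vLLM's internal code path):

```python
import torch

E4M3_MAX = 448.0

def fp8_roundtrip(t: torch.Tensor, scale: float) -> torch.Tensor:
    q = (t / scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return q.to(torch.float32) * scale

v = torch.randn(4096) * 300.0  # some entries exceed +-448
err_unscaled = (v - fp8_roundtrip(v, 1.0)).abs().max().item()                            # no JSON: scale = 1.0
err_scaled = (v - fp8_roundtrip(v, v.abs().max().item() / E4M3_MAX)).abs().max().item()  # calibrated scale
print(f"max abs error  scale=1.0: {err_unscaled:.1f}   calibrated: {err_scaled:.1f}")
```

With scale=1.0, every entry beyond ±448 collapses to ±448, while the calibrated per-tensor scale keeps the full range at FP8 precision.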
manual execution:
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# scales_path points at the per-layer KV cache scaling factors JSON described above
llm = LLM(model="/data/models/llama-2-70b-chat-hf", kv_cache_dtype="fp8", scales_path="./tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json")
prompt = "London is the capital of"
out = llm.generate(prompt, sampling_params)[0].outputs[0].text
print(out)
Performance:
We observed a 20~30% throughput increase over the FP16 baseline just by switching the KV cache to FP8 (e4m3fn), even for a 70B model served by a single MI300X.
WizardCoder-34b scores, dataset: HumanEval-Python-EN, on a single MI300X GPU
| KV cache type and config | pass@1 | pass@10 | pass@100 |
|---|---|---|---|
| FP16 (T=0.8) | 30.63% | 79.82% | 95.73% |
| FP8_scaled (T=0.8) | 30.55% | 78.76% | 94.51% |
| FP16 (BEAM, T=0.0) | 40.76% | 55.04% | 58.54% |
| FP8_scaled (BEAM, T=0.0) | 43.38% | 56.63% | 58.54% |
Contributors:
@HaiShaw, @AdrianAbeyta, @gshtras, @mawong-amd, @Alexei-V-Ivanov-AMD
@AdrianAbeyta maybe it would be nicer UX if the kv_cache_scales.json were packaged with the model and the path to it was referenced in the model's config.json? Then it could be automatically used if needed
@michael, It is beyond UX: the current model and its config.json all come from HF, and we don't intend to intervene or alter those (HF hasn't shipped FP8 quantization and its use for inference today). We decided to provide a self-contained, HF-independent, and vendor-friendly design (working with both the AMD and NVIDIA quantizers), and to bring much-needed FP8 features onto the latest HW.
@zhaoyang-star hi! do you think you could help review this PR?
Sorry for the late reply. Thanks for your great work! I will try my best to review it, though I don't have much bandwidth these days.
Thanks for your great work! I am trying my best to understand the PR. Comments have been left. There are still several questions:
- Will it automatically disable the FP8 KV cache when the GPU does not support FP8, or does an error pop up in this case? [Ans: this PR only adds functionality to AMD GPUs; it will use HW FP8 acceleration on MI300 or later and SW FP8 emulation on earlier AMD GPUs, details in `hip_float8_impl.h`]
- As fp8_e4m3fn only supports Llama-family models, will an error occur when trying to use fp8_e4m3fn with other models? [Ans: the plan is to get Llama included in this PR for the first review of a series of changes; we will enable other models under the same design in a straightforward manner. Currently with this PR, when we run another model there won't be actual scaling factors loaded, and it will just function as usual (scaling factor equals 1.0)]
- It would be better to share benchmark/accuracy results on AMD GPUs using fp8_e4m3fn. It is very useful for us. [Ans: we observed up to 25~30% throughput increases (relative to FP16) with the fp8_e4m3fn KV cache enabled on MI300X; we used the table of WizardCoder-34b scores above to represent the accuracy side of the comparison, same as your #2279]
I will run this PR locally and then more comments may be added. [Ans: sounds good. We have fixed two rounds of merge conflicts since 3/8, and will do a third for the current conflicts. FYI, it won't change any existing behavior on the NV platform]
@zhaoyang-star , had answers inline above. Thanks for your review!
Concerns about the API naming (I would like to hear @zhuohan123 and @WoosukKwon's thoughts on this):
- `--quantization-param-path`: should it be named more narrowly, something like `--scaling-factor-per-layer-json`? Even with this, I find it difficult for users without an understanding of the JSON format.
- Naming of `kv_scale` in the paged attention kernel: should it be called `scaling_factor` or `fp8_scaling`?
Overall, I think once the remaining comments are settled and these two API design questions are resolved, this PR is in good shape to merge.
For the future, it would be a lot easier to review if the renaming were isolated into a separate PR.
`--quantization-param-path` was called `scales-path` before; we renamed it in the last round of review to address a comment and to be a bit forward-looking (we expect it to cover more than KV-cache-related scaling only). `--scaling-factor-per-layer-json` is accurate to the current state of this PR, and `--layered-scaling-factor-path` seems good too.
If we all agree on a name, we can go ahead and make the change.
Seems like this PR now fails against the main branch.
Can the files hip_float8.h and hip_float8_impl.h be part of some AMD SDK going forward? They shouldn't be part of vLLM :)
That was the plan, once we have a common FP8 header released.