Add FP8 KVCache support
What does this PR do?
This PR adds support for an FP8 KV cache in Text Generation Inference (TGI) on both Nvidia and AMD GPUs. Quantizing the KV cache to 8-bit floating point (FP8) halves its memory footprint relative to FP16, which allows larger batches and longer sequences and leads to faster, more scalable text generation.
Hardware Compatibility:
- Nvidia GPUs: supports both FP8 E4M3 and FP8 E5M2 (E5M2 requires a vLLM update; see TODOs below).
- AMD GPUs: supports FP8 E4M3.
Example Usage:
text-generation-launcher --model-id <model_id> --kv-cache-dtype {fp8,fp8_e5m2}
For the E4M3 format, KV cache scaling factors should be included in the FP16 checkpoint to maintain accuracy. If no scaling factors are provided, a default of 1.0 is used, which may lead to accuracy loss.
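To illustrate why the scaling factor matters, below is a small, self-contained sketch (not TGI's or vLLM's actual kernel code) of how a per-tensor `kv_scale` is typically applied when storing KV cache values in FP8 E4M3; the shapes and magnitudes are made up for the example.

```python
# Illustrative sketch of per-tensor FP8 E4M3 quantization of KV cache values.
# FP8 E4M3 can only represent magnitudes up to ~448, so a calibrated scale maps
# the real value range into the representable range; a default scale of 1.0
# skips this step and can clip or distort large activations.
import torch

kv = torch.randn(8, 128, dtype=torch.float16) * 600.0  # magnitudes beyond the FP8 E4M3 range

# Calibrated per-tensor scale: largest magnitude mapped to the FP8 E4M3 maximum.
kv_scale = kv.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max

# Quantize on cache write, dequantize on cache read.
kv_fp8 = (kv.float() / kv_scale).to(torch.float8_e4m3fn)
kv_dequant = kv_fp8.float() * kv_scale

print((kv.float() - kv_dequant).abs().max())  # small round-trip error thanks to the scale
```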
Checkpoint Structure for KV Scales:
The FP8 KV cache scaling factors are specified through the `.kv_scale` parameter in the attention module:
model.layers.0.self_attn.kv_scale (F32)
model.layers.1.self_attn.kv_scale (F32)
This follows the structure proposed in vLLM: https://docs.vllm.ai/en/stable/quantization/fp8.html#fp8-checkpoint-structure-explanation
When providing `.kv_scale` in the model, the config should specify the proper kv_cache_torch_dtype used to generate the scales (float8_e4m3fn or float8_e4m3fnuz).
Currently, users need to extract the KV scales from an FP8 checkpoint and add them to the FP16 model; a helper script for this is provided in the PR. A rough sketch of the idea is shown below.
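This is only a minimal sketch, not the helper script shipped with the PR: it assumes single-file safetensors checkpoints, and the directory and file names (`fp8-checkpoint`, `fp16-checkpoint`, `model.safetensors`) are placeholders.

```python
# Sketch: copy the per-layer `.kv_scale` tensors from an FP8 checkpoint into an
# FP16 checkpoint, and record which FP8 dtype the scales were generated for.
import json
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

fp8_dir = Path("fp8-checkpoint")    # checkpoint that already contains kv scales
fp16_dir = Path("fp16-checkpoint")  # FP16 checkpoint to augment

# Collect all `model.layers.<i>.self_attn.kv_scale` tensors (stored as F32).
kv_scales = {}
for shard in fp8_dir.glob("*.safetensors"):
    tensors = load_file(shard)
    kv_scales.update(
        {k: v.to(torch.float32) for k, v in tensors.items() if k.endswith(".kv_scale")}
    )

# Merge the scales into the (single-shard, for simplicity) FP16 checkpoint.
fp16_path = fp16_dir / "model.safetensors"
fp16_tensors = load_file(fp16_path)
fp16_tensors.update(kv_scales)
save_file(fp16_tensors, fp16_path)

# Record which dtype the scales were generated for (see the note above).
config_path = fp16_dir / "config.json"
config = json.loads(config_path.read_text())
config["kv_cache_torch_dtype"] = "float8_e4m3fn"  # or "float8_e4m3fnuz" on AMD
config_path.write_text(json.dumps(config, indent=2))
```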
Sample Models with KV scales: Models with FP8 KV Cache
Todos:
- [x] Documentation
- [ ] Tests
- [ ] Update VLLM for CUDA to support E5M2. @Narsil could you help with this?
- [ ] Currently only supports Llama; support for other models will follow in this or subsequent PRs.
Happy to help with the rebase btw.
Thanks for the review @Narsil @fxmarty I will rebase and address the comments.
Regarding the format for loading the FP8 scales:
VLLM offers two methods:
- quantization-param-path: this uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. An example can be found here. This file is generated using the Nvidia AMMO quantizer available here.
- Direct loading from checkpoints: this method was introduced in one of the recent PRs and is located here.
VLLM intends to deprecate the quantization-param-path method soon, favoring the use of checkpoints for loading scales. Therefore, I would update our approach to also load scales using checkpoints.
Removed the quantization-param-path altogether: This method is already deprecated in VLLM, based on discussions here: https://github.com/vllm-project/vllm/issues/4532
Also, I don't see any core logic to actually handle FP8. Are the kernels ready? Is it possible to test / add tests?
The core logic for handling FP8 is in the paged attention kernel in vLLM, which comes with the necessary kernel tests. If you have any specific tests in mind, we can discuss them. vLLM includes tests that compare the output with an expected FP8 reference output, as seen in https://github.com/comaniac/vllm/blob/main/tests/models/test_fp8.py. We can add a similar test if required; a rough sketch follows.
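For example, an end-to-end check could look like the sketch below. It assumes a TGI server launched with `--kv-cache-dtype fp8` is listening on localhost:8080, and the expected completion is a placeholder that a real test would pin per model.

```python
# Sketch of an end-to-end greedy-decoding check against a running TGI server.
import requests

PROMPTS_AND_EXPECTED = {
    "The capital of France is": " Paris",  # placeholder reference completion
}


def test_fp8_kv_cache_generation():
    for prompt, expected in PROMPTS_AND_EXPECTED.items():
        response = requests.post(
            "http://localhost:8080/generate",
            json={
                "inputs": prompt,
                "parameters": {"max_new_tokens": 4, "do_sample": False},
            },
            timeout=60,
        )
        response.raise_for_status()
        generated = response.json()["generated_text"]
        # Greedy decoding with the FP8 KV cache should stay close to the FP16 reference.
        assert generated.startswith(expected)
```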
Closing this, as FP8 KV cache support was added in https://github.com/huggingface/text-generation-inference/pull/2603.
More support is coming (for pre-scaled FP8 KV cache).