
Add FP8 KVCache support

mht-sharma opened this pull request 1 year ago • 5 comments

What does this PR do?

This PR introduces support for an FP8 KV cache in Text Generation Inference (TGI), significantly improving performance and memory efficiency on both Nvidia and AMD GPUs. By quantizing the KV cache to an 8-bit floating-point (FP8) format, we can greatly reduce its memory footprint, leading to faster and more scalable text generation.
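
For a rough sense of the savings, here is a back-of-the-envelope calculation in Python; the Llama-2-7B-like shapes and the sequence length are assumptions used purely for illustration:

# KV cache bytes = 2 (K and V) * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_element
num_layers, num_kv_heads, head_dim = 32, 32, 128   # assumed Llama-2-7B-like shapes
seq_len = 4096                                      # assumed context length

def kv_cache_gib(bytes_per_element: int) -> float:
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_element / 2**30

print(f"FP16 KV cache: {kv_cache_gib(2):.1f} GiB")  # ~2.0 GiB for one 4096-token sequence
print(f"FP8  KV cache: {kv_cache_gib(1):.1f} GiB")  # ~1.0 GiB, i.e. half the footprint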

Hardware Compatibility:

  • Nvidia GPUs: Supports both FP8E4M3 and FP8E5M2 (TODO: needs a vLLM update).
  • AMD GPUs: Supports FP8E4M3.

Example Usage:

text-generation-launcher --model-id <model_id> --kv-cache-dtype {fp8|fp8_e5m2}

For the E4M3 format, KV cache scaling factors should be included in the FP16 checkpoint to maintain accuracy. The scaling factor defaults to 1.0 if not provided, which may lead to accuracy loss.
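
As a minimal sketch of why the scale matters (this only illustrates the per-tensor quantization arithmetic, assuming a PyTorch build with float8_e4m3fn support; TGI/vLLM apply the scale inside the paged attention kernels, not in Python):

import torch

E4M3_MAX = 448.0  # largest normal value representable in float8_e4m3fn

def quantize_kv(x: torch.Tensor, kv_scale: float) -> torch.Tensor:
    # Divide by the scale, clamp into the representable range, cast to FP8.
    return (x / kv_scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)

def dequantize_kv(x_fp8: torch.Tensor, kv_scale: float) -> torch.Tensor:
    return x_fp8.to(torch.float16) * kv_scale

k = torch.tensor([0.02, 1.5, -600.0, 900.0], dtype=torch.float16)  # toy K values with outliers
calibrated_scale = float(k.abs().max()) / E4M3_MAX                 # per-tensor scale from calibration

for scale in (1.0, calibrated_scale):
    roundtrip = dequantize_kv(quantize_kv(k, scale), scale)
    print(f"kv_scale={scale:.4f} -> {roundtrip.tolist()}")
# With kv_scale=1.0 the outliers are clipped to +/-448; with a calibrated scale they survive.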

Checkpoint Structure for KV Scales:

The FP8 KV cache scaling factors are specified through the .kv_scale parameter of the attention module:

model.layers.0.self_attn.kv_scale                < F32
model.layers.1.self_attn.kv_scale                < F32

This follows the structure proposed in vLLM: https://docs.vllm.ai/en/stable/quantization/fp8.html#fp8-checkpoint-structure-explanation

When .kv_scale is provided in the model, the config should specify the kv_cache_torch_dtype that was used to generate the scales (float8_e4m3fn or float8_e4m3fnuz).
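
A minimal sketch of how such a checkpoint could be inspected with safetensors; the single-file checkpoint name is an assumption (real checkpoints are often sharded), and this is not the loading code from the PR:

import json
from safetensors import safe_open

checkpoint = "model.safetensors"            # hypothetical single-file FP16 checkpoint with .kv_scale tensors
config = json.load(open("config.json"))     # assuming kv_cache_torch_dtype is stored in config.json

# The dtype the scales were generated for, e.g. float8_e4m3fn (CUDA) or float8_e4m3fnuz (ROCm).
print("kv_cache_torch_dtype:", config.get("kv_cache_torch_dtype"))

with safe_open(checkpoint, framework="pt") as f:
    kv_scales = {
        name: f.get_tensor(name)
        for name in f.keys()
        if name.endswith(".self_attn.kv_scale")
    }

for name, scale in sorted(kv_scales.items()):
    print(f"{name}: {scale.item():.6f}")    # one F32 scalar per layer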

Currently, users need to extract the KV scales from an FP8 checkpoint and add them to the FP16 model. A helper script for this is provided in the PR.
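
The following is only a hedged sketch of the idea behind such a script, assuming the FP8 checkpoint already stores per-layer scales under the same .kv_scale names (real quantized checkpoints may name or shard the tensors differently, which is what the actual helper script handles):

from safetensors.torch import load_file, save_file

fp8_ckpt = "fp8-model/model.safetensors"     # hypothetical quantized checkpoint that carries the scales
fp16_ckpt = "fp16-model/model.safetensors"   # hypothetical FP16 checkpoint to augment

fp8_tensors = load_file(fp8_ckpt)
fp16_tensors = load_file(fp16_ckpt)

# Copy only the per-layer KV scales into the FP16 state dict and re-save it.
kv_scales = {name: t for name, t in fp8_tensors.items() if name.endswith(".kv_scale")}
fp16_tensors.update(kv_scales)

save_file(fp16_tensors, "model_with_kv_scales.safetensors", metadata={"format": "pt"})
print(f"Copied {len(kv_scales)} kv_scale tensors")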

Sample Models with KV scales: Models with FP8 KV Cache

Todos:

  • [x] Documentation
  • [ ] Tests
  • [ ] Update vLLM for CUDA to support E5M2. @Narsil, could you help with this?
  • [ ] Only Llama is supported for now; support for other models will be added in this or follow-up PRs

mht-sharma avatar Jun 06 '24 07:06 mht-sharma

Happy to help with the rebase btw.

Narsil avatar Jun 06 '24 16:06 Narsil

Thanks for the review @Narsil @fxmarty. I will rebase and address the comments.

Regarding the format for loading the FP8 scales:

vLLM offers two methods:

  • quantization-param-path: This uses a JSON file (kv_cache_scales.json) containing per-tensor scaling factors for each layer. An example can be found here. This file is generated using the Nvidia AMMO quantizer available here.

  • Direct loading from checkpoints: This method was introduced in a recent PR and is located here.

vLLM intends to deprecate the quantization-param-path method soon in favor of loading scales from checkpoints. Therefore, I will update our approach to also load the scales from checkpoints.

mht-sharma avatar Jun 14 '24 12:06 mht-sharma

Removed the quantization-param-path approach altogether: this method is already deprecated in vLLM, based on the discussion in https://github.com/vllm-project/vllm/issues/4532.

mht-sharma avatar Jun 24 '24 13:06 mht-sharma


Also, I don't see any core logic to actually handle the FP8. Are the kernels ready? Is it possible to test / add tests?

The core logic for handling FP8 is managed by the paged attention kernel in vLLM, along with the necessary kernel tests. If you have any specific tests in mind, we can discuss them. vLLM includes tests that compare the output with the expected FP8 output, as seen in https://github.com/comaniac/vllm/blob/main/tests/models/test_fp8.py. We can add a similar test if required.
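
A possible shape for such a test, sketched against TGI's HTTP /generate endpoint rather than the repository's integration-test harness; the ports and the assumption of two locally running servers (one launched with --kv-cache-dtype fp8, one without) are illustrative only:

import requests

PROMPT = "The theory of relativity states that"
PAYLOAD = {"inputs": PROMPT, "parameters": {"max_new_tokens": 32, "do_sample": False}}

def generate(port: int) -> str:
    resp = requests.post(f"http://localhost:{port}/generate", json=PAYLOAD, timeout=60)
    resp.raise_for_status()
    return resp.json()["generated_text"]

fp16_out = generate(8080)  # assumed server without --kv-cache-dtype
fp8_out = generate(8081)   # assumed server launched with --kv-cache-dtype fp8

# With greedy decoding the outputs should match closely; small divergences are
# possible because FP8 KV cache quantization is lossy.
print("outputs match:", fp16_out == fp8_out)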

mht-sharma avatar Jun 24 '24 15:06 mht-sharma

Closing this, as FP8 KV cache support was added in https://github.com/huggingface/text-generation-inference/pull/2603.

More support is coming (for pre-scaled FP8 KV cache).

Narsil avatar Oct 08 '24 09:10 Narsil