
`--use_fp8_context_fmha` is broken with Llama models

Open ttim opened this issue 6 months ago • 0 comments

System Info

  • 1x H100 SXM
  • tensorrt-llm 0.12.0.dev2024080600

Who can help?

@Tracin


Information

  • [ ] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

1a. Build an FP8-quantized checkpoint with the LLaMAForCausalLM.quantize API, passing QuantConfig(quant_algo=QuantAlgo.FP8) as an argument (see the sketch after this list).
1b. Same, but with QuantConfig(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8) (is any difference expected between the two in practice?).
2. Try building the engine with --use_fp8_context_fmha.
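
For reference, a minimal sketch of steps 1a/1b, assuming the 0.12-era Python API where LLaMAForCausalLM.quantize takes a Hugging Face model directory, an output directory, and a quant_config. The import paths, keyword names, and file paths below are assumptions and may differ between releases:

```python
# Sketch only: exact import paths and quantize() keyword names may differ
# between TensorRT-LLM releases; the model/output paths are hypothetical.
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

HF_MODEL_DIR = "/path/to/llama-hf"        # hypothetical HF checkpoint location
CKPT_DIR = "/path/to/llama-fp8-ckpt"      # hypothetical output location

# 1a: FP8 quantization of weights/activations only
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8)

# 1b: additionally quantize the KV cache to FP8
# quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
#                            kv_cache_quant_algo=QuantAlgo.FP8)

LLaMAForCausalLM.quantize(
    HF_MODEL_DIR,
    CKPT_DIR,
    dtype="float16",
    quant_config=quant_config,
)
```

Step 2 then builds the engine from CKPT_DIR with trtllm-build and --use_fp8_context_fmha.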

Expected behavior

The engine builds successfully and performs better.

Actual behavior

For 1a, the engine build fails with:

terminate called after throwing an instance of 'tensorrt_llm::common::TllmException'
  what():  [TensorRT-LLM][ERROR] Assertion failed: getIdx() should not be used with entry 17
 (/home/jenkins/agent/workspace/LLM/main/L0_MergeRequest/tensorrt_llm/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp:131)
1       0x7b2f125ba395 /usr/local/lib/python3.10/dist-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so(+0x51395) [0x7b2f125ba395]
2       0x7b2f1260481b tensorrt_llm::plugins::GPTAttentionPlugin::getIdx(tensorrt_llm::plugins::GPTAttentionPlugin::IdxEntry const&) const + 107
3       0x7b2f12604ad4 tensorrt_llm::plugins::GPTAttentionPlugin::supportsFormatCombination(int, nvinfer1::PluginTensorDesc const*, int, int) + 660
4       0x7b338e7ddae4 /usr/local/lib/python3.10/dist-packages/tensorrt_libs/libnvinfer.so.10(+0xb8aae4) [0x7b338e7ddae4]

For 1b, it fails with:

RuntimeError: Paged Context FMHA doesn't work with int8 kv cache currently.

The INT8 KV cache setting seems to come from config.json:

...
"producer": {
        "name": "modelopt",
        "version": "0.15.1"
    },
...
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "INT8"
    },

After manually changing kv_cache_quant_algo to "FP8" it started working, but I'm not sure that's the correct fix (see the sketch below).
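
For concreteness, a minimal sketch of that manual workaround; it just rewrites the kv_cache_quant_algo field shown above, and the checkpoint path is hypothetical:

```python
# Sketch of the manual workaround: change the kv_cache_quant_algo that modelopt
# wrote as "INT8" in the quantized checkpoint's config.json to "FP8".
import json

config_path = "/path/to/llama-fp8-ckpt/config.json"  # hypothetical path

with open(config_path) as f:
    config = json.load(f)

config["quantization"]["kv_cache_quant_algo"] = "FP8"

with open(config_path, "w") as f:
    json.dump(config, f, indent=4)
```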

Additional notes

Is LlamaModel.quantize the recommended way to build checkpoints? Should I use something else instead?

ttim · Aug 13 '24 21:08