TensorRT-LLM
`--use_fp8_context_fmha` is broken with Llama models
System Info
- 1x H100 SXM
- tensorrt-llm 0.12.0.dev2024080600
Who can help?
@Tracin
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
1a. Build an FP8-quantized checkpoint using the `LLaMAForCausalLM.quantize` API, passing `QuantConfig(quant_algo=QuantAlgo.FP8)` as an argument (see the sketch after this list).
1b. Same, but with `QuantConfig(quant_algo=QuantAlgo.FP8, kv_cache_quant_algo=QuantAlgo.FP8)` (is any difference expected in practice?).
2. Try building the engine with `--use_fp8_context_fmha`.
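A minimal sketch of the checkpoint step (1a/1b). Paths are placeholders and the exact import locations / keyword names are what I believe the 0.12 API expects, so they may differ slightly:

```python
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo

HF_MODEL_DIR = "/models/llama-3-8b-instruct"  # placeholder
CKPT_DIR = "/ckpts/llama3-8b-fp8"             # placeholder

# 1a: FP8 weights/activations only
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8)
# 1b: additionally request an FP8 KV cache
# quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
#                            kv_cache_quant_algo=QuantAlgo.FP8)

LLaMAForCausalLM.quantize(
    HF_MODEL_DIR,
    CKPT_DIR,
    dtype="float16",
    quant_config=quant_config,
)
```

Step 2 is then `trtllm-build --checkpoint_dir /ckpts/llama3-8b-fp8 --output_dir ... --use_fp8_context_fmha ...` with the remaining options left at their defaults.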
Expected behavior
The engine builds successfully and performs better.
Actual behavior
For 1a, the engine build fails with:
terminate called after throwing an instance of 'tensorrt_llm::common::TllmException'
what(): [TensorRT-LLM][ERROR] Assertion failed: getIdx() should not be used with entry 17
(/home/jenkins/agent/workspace/LLM/main/L0_MergeRequest/tensorrt_llm/cpp/tensorrt_llm/plugins/gptAttentionPlugin/gptAttentionPlugin.cpp:131)
1 0x7b2f125ba395 /usr/local/lib/python3.10/dist-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so(+0x51395) [0x7b2f125ba395]
2 0x7b2f1260481b tensorrt_llm::plugins::GPTAttentionPlugin::getIdx(tensorrt_llm::plugins::GPTAttentionPlugin::IdxEntry const&) const + 107
3 0x7b2f12604ad4 tensorrt_llm::plugins::GPTAttentionPlugin::supportsFormatCombination(int, nvinfer1::PluginTensorDesc const*, int, int) + 660
4 0x7b338e7ddae4 /usr/local/lib/python3.10/dist-packages/tensorrt_libs/libnvinfer.so.10(+0xb8aae4) [0x7b338e7ddae4]
For 1b, it fails with:
RuntimeError: Paged Context FMHA doesn't work with int8 kv cache currently.
The INT8 KV cache setting seems to come from `config.json`:
...
"producer": {
"name": "modelopt",
"version": "0.15.1"
},
...
"quantization": {
"quant_algo": "FP8",
"kv_cache_quant_algo": "INT8"
},
After manually changing `kv_cache_quant_algo` to `"FP8"` the engine build started working, but I'm not sure whether that is the correct fix (workaround sketched below).
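For reference, the manual workaround was effectively the following throwaway script (checkpoint path is a placeholder); I'm not proposing this as a proper fix:

```python
import json
from pathlib import Path

# config.json of the checkpoint produced in step 1b (placeholder path)
config_path = Path("/ckpts/llama3-8b-fp8/config.json")

config = json.loads(config_path.read_text())
# modelopt wrote "INT8" here even though an FP8 KV cache was requested
config["quantization"]["kv_cache_quant_algo"] = "FP8"
config_path.write_text(json.dumps(config, indent=4))
```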
Additional notes
Is `LlamaModel.quantize` the recommended way to build checkpoints, or should I use something else instead?