
`-1` token id with Mixtral FP8 and tensorrt_llm 0.11.0


  • CPU architecture: x86_64
  • GPU: NVIDIA H100
  • Libraries
    • TensorRT-LLM: v0.11.0
    • TensorRT: 10.1.0
    • Modelopt: 0.13.1
    • CUDA: 12.3
  • NVIDIA driver version: 535.129.03

Hello, I'm experiencing weird behaviour with Mixtral FP8 (weights and KV cache) when doing inference with batch size 16 and TensorRT-LLM 0.11.0: occasionally the generation starts with token id -1, producing a sequence of tokens that can't be decoded.

Unfortunately I can't share the prompts or the finetuned model, but here are some details that hopefully help pin down the root cause:

  1. The issue happens only on specific batches and only when the batch is full. I've tried to isolate potential "broken" prompts inside the batch, but if I run inference sample by sample (batch size = 1) the generation is correct. Even if I run inference with combinations of these prompts (up to 15), the generation is correct. As soon as I recreate the original full batch (even after shuffling the order), some generations start with token id -1.
  2. The output ids tensor of a failed generation starts with the correct prompt token ids, then a -1 (the first generated token), followed by many "garbage" tokens, all within the valid range (as if the first value of the outputs list in run.py were -1); see the sketch after this list.
  3. The issue persists even if we redo calibration and compilation; it fails on the same batch.
  4. I started having this problem after updating to 0.11.0.
  5. Suspecting the issue could be related to the context phase, I disabled use_fp8_context_fmha and the problem disappears (at least for now).
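
For reference, a minimal sketch of the check I use to spot the bad generations (the names output_ids and prompt_lens are hypothetical; it assumes the per-sequence token ids returned by a run.py-style generation are available as plain Python lists):

# Minimal sketch (hypothetical names): flag sequences whose generated part
# contains the -1 sentinel, given per-sequence token ids and prompt lengths.
def find_bad_generations(output_ids, prompt_lens):
    bad = []
    for i, (ids, prompt_len) in enumerate(zip(output_ids, prompt_lens)):
        generated = ids[prompt_len:]
        if any(token_id < 0 for token_id in generated):  # -1 cannot be decoded
            bad.append(i)
    return bad

# The second sequence starts its generation with -1 and is flagged.
print(find_bad_generations([[5, 6, 7, 8], [5, 6, -1, 1234]], [2, 2]))  # -> [1]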

Conversion command:

python quantize.py --model_dir <model_path> \
                   --dtype float16 \
                   --qformat fp8 \
                   --kv_cache_dtype fp8 \
                   --output_dir <output_path> \
                   --calib_size 512 \
                   --tp_size 1 \
                   --calib_dataset <dataset_path>

Compilation command:

trtllm-build --checkpoint_dir=<output_path> \
             --max_beam_width=1 \
             --max_seq_len=32768 \
             --max_num_tokens=26230 \
             --max_batch_size=16 \
             --context_fmha=enable \
             --use_custom_all_reduce=disable \
             --use_fp8_context_fmha=enable \
             --output_dir=<output>

Some questions:

Thanks for your hard work!

v-dicicco · Jul 31 '24 23:07

This problem occurs when I set the temperature very low (like 0.0001); I don't know the exact cause.

vonchenplus · Aug 01 '24 00:08

@vonchenplus thanks for the answer. In my case I'm not using a very low temperature; it is 0.1.

v-dicicco · Aug 07 '24 21:08

Looks like it is related to use_fp8_context_fmha. @vonchenplus, did you also enable FP8 context FMHA?

Tracin · Aug 21 '24 07:08

@v-dicicco please disable FP8 context FMHA for now. The accuracy issues might happen only on certain models, due to the FP8 computation. We are adding QKV tensor scales, which might help with this; we will let you know when it is available.

Also, it would be great if you could give it a try with Llama 7B (it should work; otherwise there is a bug).

PerkzZheng · Aug 21 '24 07:08

@PerkzZheng @Tracin thanks! I can confirm that disabling use_fp8_context_fmha solves the issue. If we run tests on Llama 7B, I will update here.

Regarding the decrease in quality of some models: do you know whether Mixtral 8x7B is expected to retain accuracy with FP8 without a significant loss? We are seeing a significant quality drop in our finetuned models with FP8 compared to INT8 weight-only and FP16, even when carefully crafting the calibration dataset, using the official examples/quantization/quantize.py, and using the same compilation flags across the tests (with trt-llm 0.11.0).
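
For what it's worth, a rough sketch of how we quantify the gap between two engines on the same prompts (fp16_outputs and fp8_outputs are hypothetical names for lists of generated token-id lists collected beforehand; token-level agreement is only a coarse proxy for quality):

# Rough sketch (hypothetical names): token-level agreement between the
# generations of two engines (e.g. FP16 vs FP8) on the same prompts.
def token_agreement(fp16_outputs, fp8_outputs):
    matched, total = 0, 0
    for a, b in zip(fp16_outputs, fp8_outputs):
        matched += sum(x == y for x, y in zip(a, b))
        total += max(len(a), len(b))
    return matched / total if total else 1.0

print(token_agreement([[1, 2, 3]], [[1, 2, 4]]))  # -> 0.666...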

v-dicicco · Aug 21 '24 16:08

I faced this problem with KV cache reuse enabled on a Llama-like model:

{"asctime":"2024-09-12 12:33:46,923901","level":"WARNING","module":"Feed ( src/lib/pipeline/src/stream_tokenizer.cpp:36 ) ","task_id":"7F276083E300","thread_id":"0x00007F279E3FF400","text":"token is out of range: -1"}
{"asctime":"2024-09-12 12:33:46,923913","level":"WARNING","module":"Feed ( src/lib/pipeline/src/stream_tokenizer.cpp:36 ) ","task_id":"7F2786264900","thread_id":"0x00007F2796BFF400","text":"token is out of range: -1"}
{"asctime":"2024-09-12 12:33:46,923905","level":"WARNING","module":"Feed ( src/lib/pipeline/src/stream_tokenizer.cpp:36 ) ","task_id":"7F278C05E600","thread_id":"0x00007F27749FE400","text":"token is out of range: -1"}
{"asctime":"2024-09-12 12:33:46,923968","level":"WARNING","module":"Feed ( src/lib/pipeline/src/stream_tokenizer.cpp:36 ) ","task_id":"7F278E86C000","thread_id":"0x00007F27453FF400","text":"token is out of range: -1"}

Model config and build command:

{
    "producer": {
        "name": "modelopt",
        "version": "0.13.1"
    },
    "architecture": "LlamaForCausalLM",
    "dtype": "float16",
    "logits_dtype": "float16",
    "num_hidden_layers": 60,
    "num_attention_heads": 64,
    "num_key_value_heads": 16,
    "hidden_size": 6656,
    "norm_epsilon": 1e-06,
    "vocab_size": 42064,
    "max_position_embeddings": 8192,
    "hidden_act": "silu",
    "use_parallel_embedding": true,
    "embedding_sharding_dim": 0,
    "quantization": {
        "quant_algo": "FP8",
        "kv_cache_quant_algo": "FP8"
    },
    "mapping": {
        "world_size": 1,
        "tp_size": 1,
        "pp_size": 1
    },
    "head_size": 128,
    "intermediate_size": 17920,
    "position_embedding_type": "rope_gpt_neox",
    "share_embedding_table": false,
    "residual_mlp": false,
    "bias": false,
    "rotary_pct": 1.0,
    "rank": 0,
    "decoder": "llama",
    "rmsnorm": true,
    "lm_head_bias": false,
    "rotary_base": 100000,
    "rotary_scaling": null
}
trtllm-build  --checkpoint_dir 1-gpu-tmp/ --output_dir 1-gpu --max_input_len 31744 --max_seq_len 32768 --max_num_tokens 32768 --max_batch_size 256 --gemm_plugin fp8 --use_paged_context_fmha enable

If I build the model a different way and disable cache reuse, then everything works fine:

trtllm-build  --checkpoint_dir 1-gpu-tmp/ --output_dir 1-gpu --max_input_len 31744 --max_seq_len 32768 --max_num_tokens 32768 --max_batch_size 256 --gemm_plugin fp8
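
As a quick sanity check when comparing the two builds, I read back the options recorded in the engine directory's config.json; a sketch below (treating the build_config/plugin_config key names as assumptions, since they may differ between TensorRT-LLM versions):

import json

# Sketch: print the FMHA / paged-KV-cache build options recorded in an engine
# directory's config.json. Key names are assumptions and may vary by version.
with open("1-gpu/config.json") as f:
    cfg = json.load(f)

plugin_cfg = cfg.get("build_config", {}).get("plugin_config", {})
for key in ("use_paged_context_fmha", "use_fp8_context_fmha",
            "paged_kv_cache", "context_fmha"):
    print(key, "=", plugin_cfg.get(key, "<not recorded>"))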

I think this is due to https://github.com/NVIDIA/TensorRT-LLM/issues/2217

akhoroshev · Sep 12 '24 09:09