llama2-7b bad results for int8-kv-cache + per-channel-int8-weight
System Info
GPU: RTX 3090; TensorRT-LLM version: 0.7.1
Who can help?
No response
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
Reproduction
- get bin_model:
  python hf_llama_convert.py -i /root/models/Llama-2-7b/ -o /root/TensorRT-LLM/examples/llama/llama2_7b_w8_int8_kv_cache/ --calibrate-kv-cache -t fp16
- build model:
  python build.py --bin_model_dir /root/TensorRT-LLM/examples/llama/llama2_7b_w8_int8_kv_cache/bin_model_dir/ --dtype float16 --use_gpt_attention_plugin float16 --use_gemm_plugin float16 --output_dir /root/TensorRT-LLM/examples/llama/llama2_7b_w8_int8_kv_cache/1-gpu --int8_kv_cache --use_weight_only
- test the model:
  python mmlu.py --hf_model_dir /root/models/Llama-2-7b/ --engine_dir /root/TensorRT-LLM/examples/llama/llama2_7b_w8_int8_kv_cache/1-gpu/ --test_trt_llm
  (mmlu.py is provided by TensorRT-LLM here: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/mmlu.py)
Unfortunately, step 3 gives me:
Average accuracy 0.297 - math
Average accuracy 0.399 - health
Average accuracy 0.300 - physics
Average accuracy 0.519 - business
Average accuracy 0.361 - biology
Average accuracy 0.274 - chemistry
Average accuracy 0.299 - computer science
Average accuracy 0.349 - economics
Average accuracy 0.317 - engineering
Average accuracy 0.367 - philosophy
Average accuracy 0.513 - other
Average accuracy 0.439 - history
Average accuracy 0.404 - geography
Average accuracy 0.475 - politics
Average accuracy 0.380 - psychology
Average accuracy 0.512 - culture
Average accuracy 0.330 - law
Average accuracy 0.306 - STEM
Average accuracy 0.367 - humanities
Average accuracy 0.409 - social sciences
Average accuracy 0.457 - other (business, health, misc.)
**Average accuracy: 0.384**
The final MMLU accuracy is 38.4, while the FP16 accuracy is 45.9, which is a very large drop. According to several LLM quantization papers, accuracy should not drop this much in this setting.
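For context, here is a minimal sketch of how the overall number rolls up from the per-category results above, assuming (as in the standard MMLU evaluation) a question-count-weighted mean; the categories and counts in the snippet are hypothetical:

# Minimal sketch: roll per-category MMLU accuracies up into one overall score.
# Assumes a question-count-weighted mean, as in the standard MMLU harness;
# the categories and question counts below are hypothetical.
per_category = {
    # category: (accuracy, number_of_questions)
    "math": (0.297, 400),
    "health": (0.399, 300),
    "physics": (0.300, 350),
}

total_correct = sum(acc * n for acc, n in per_category.values())
total_questions = sum(n for _, n in per_category.values())
print(f"Average accuracy: {total_correct / total_questions:.3f}")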
The config.json generated by build.py looks like this:
{
"builder_config": {
"autopp_config": null,
"gather_context_logits": false,
"gather_generation_logits": false,
"hf_modules_to_trtllm_modules": {
"down_proj": "mlp_4h_to_h",
"gate_proj": "mlp_h_to_4h",
"k_proj": "attn_k",
"o_proj": "attn_dense",
"q_proj": "attn_q",
"up_proj": "mlp_gate",
"v_proj": "attn_v"
},
"hidden_act": "silu",
"hidden_size": 4096,
"int8": true,
"lora_target_modules": null,
"max_batch_size": 8,
"max_beam_width": 1,
"max_input_len": 2048,
"max_num_tokens": null,
"max_output_len": 512,
"max_position_embeddings": 2048,
"max_prompt_embedding_table_size": 0,
"mlp_hidden_size": 11008,
"name": "llama",
"num_heads": 32,
"num_kv_heads": 32,
"num_layers": 32,
"parallel_build": false,
"pipeline_parallel": 1,
"precision": "float16",
"quant_mode": 66,
"tensor_parallel": 1,
"trtllm_modules_to_hf_modules": {
"attn_dense": "o_proj",
"attn_k": "k_proj",
"attn_q": "q_proj",
"attn_v": "v_proj",
"mlp_4h_to_h": "down_proj",
"mlp_gate": "up_proj",
"mlp_h_to_4h": "gate_proj"
},
"use_refit": false,
"vocab_size": 32000
},
"plugin_config": {
"attention_qk_half_accumulation": false,
"bert_attention_plugin": false,
"context_fmha_type": 0,
"enable_xqa": false,
"gemm_plugin": "float16",
"gpt_attention_plugin": "float16",
"identity_plugin": false,
"layernorm_plugin": false,
"layernorm_quantization_plugin": false,
"lookup_plugin": false,
"lora_plugin": false,
"multi_block_mode": false,
"nccl_plugin": false,
"paged_kv_cache": false,
"quantize_per_token_plugin": false,
"quantize_tensor_plugin": false,
"remove_input_padding": false,
"rmsnorm_plugin": false,
"rmsnorm_quantization_plugin": false,
"smooth_quant_gemm_plugin": false,
"tokens_per_block": 0,
"use_context_fmha_for_generation": false,
"use_custom_all_reduce": false,
"use_paged_context_fmha": false,
"weight_only_groupwise_quant_matmul_plugin": false,
"weight_only_quant_matmul_plugin": "float16"
}
}
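As a side note, the quant_mode: 66 field can be decoded into individual quantization flags. Below is a minimal sketch, assuming the bit layout matches the QuantMode IntFlag in tensorrt_llm/quantization/mode.py for this release (the bit positions here are an assumption and may differ between versions):

from enum import IntFlag, auto

# Assumed bit layout of tensorrt_llm.quantization.QuantMode (v0.7.x);
# verify against tensorrt_llm/quantization/mode.py for your release.
class QuantMode(IntFlag):
    INT4_WEIGHTS = auto()   # 1
    INT8_WEIGHTS = auto()   # 2
    ACTIVATIONS = auto()    # 4
    PER_CHANNEL = auto()    # 8
    PER_TOKEN = auto()      # 16
    PER_GROUP = auto()      # 32
    INT8_KV_CACHE = auto()  # 64
    FP8_KV_CACHE = auto()   # 128
    FP8_QDQ = auto()        # 256

# 66 = 2 + 64, i.e. INT8 weight-only plus INT8 KV cache under this assumed layout.
print(QuantMode(66))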
Is there any bug in the quantization code?
Expected behavior
The MMLU accuracy should not drop this much.
actual behavior
The MMLU accuracy drops a lot (38.4 vs. 45.9 for FP16).
additional notes
None.
Potentially related issues:
https://github.com/NVIDIA/TensorRT-LLM/issues/948
https://github.com/NVIDIA/TensorRT-LLM/issues/964
How about testing without int8_kv_cache?
@Tracin With int8_weight_only alone, the accuracy is good.
So I think it is similar to https://github.com/NVIDIA/TensorRT-LLM/issues/889. You can try FP8, since the FP8 KV cache has better accuracy than INT8.
@Tracin
Have you tested int8-weight + int8-kv-cache MMLU accuracy on any model in your experiments? This combination is the official quantization example in TensorRT-LLM.
Yes, we tested int8-kv and int8 weight-only separately with LLaMA1-7b, and the MMLU score is similar to FP16. According to your test, LLaMA2-7b + int8kv has bad accuracy, right? I will check it.
@Tracin For llama2-7b: int8 weight + int8 kv-cache gives bad accuracy; int8 weight-only gives good accuracy.
Thanks! I think we can drop int8 weight-only for easier debugging. You also mentioned your own quantization code; did you use the same kv_cache_scaling_factors?
@Tracin With separate K and V scales (per-tensor, static), the accuracy is good. I will test the merged K/V scale case later today.
@Tracin Why does TensorRT-LLM have to merge QKV?
Launching one larger GEMM can be more efficient than launching three small kernels. BTW, in the SmoothQuant implementation in TRT-LLM, we use separate (three) scales even if you choose per-tensor mode for QKV.
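For illustration only (this is not TensorRT-LLM's kernel code), a sketch of the fused-QKV idea: the three projection weights are concatenated along the output dimension so a single GEMM produces Q, K and V at once, while each projection can still keep its own quantization scale:

import torch

hidden = 512
x = torch.randn(8, hidden)                                    # [tokens, hidden]
wq, wk, wv = (torch.randn(hidden, hidden) for _ in range(3))  # per-projection weights

# Three separate GEMMs ...
q, k, v = x @ wq, x @ wk, x @ wv

# ... versus one fused GEMM over the concatenated weight, split afterwards.
w_qkv = torch.cat([wq, wk, wv], dim=1)                        # [hidden, 3 * hidden]
q2, k2, v2 = (x @ w_qkv).split(hidden, dim=1)
assert torch.allclose(q, q2, rtol=1e-4, atol=1e-4)

# Even with a fused GEMM, per-tensor quantization can keep one scale per projection.
scales = {name: w.abs().max() / 127 for name, w in (("q", wq), ("k", wk), ("v", wv))}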
With separate K and V scales (per-tensor, static), the accuracy is fine. With a merged K/V scale, the accuracy drops by about 1.5% ~ 2% (see the sketch below). @Tracin
Have you done any experiments on llama2-7b with int8-weight + int8-kv-cache?
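To make the merged-vs-separate question concrete, here is a small numpy sketch under the assumption of static per-tensor INT8 KV-cache quantization: a single merged scale has to cover whichever of K or V has the larger dynamic range, which coarsens the quantization grid for the other tensor (the distributions below are purely illustrative):

import numpy as np

rng = np.random.default_rng(0)
k = rng.normal(0.0, 1.0, 4096)   # illustrative K cache values, wide range
v = rng.normal(0.0, 0.2, 4096)   # illustrative V cache values, narrower range

def int8_roundtrip_error(x, scale):
    q = np.clip(np.round(x / scale), -128, 127)
    return np.abs(q * scale - x).mean()

scale_k = np.abs(k).max() / 127
scale_v = np.abs(v).max() / 127
scale_merged = max(scale_k, scale_v)   # one shared scale for both K and V

print("V error with its own scale:   ", int8_roundtrip_error(v, scale_v))
print("V error with the merged scale:", int8_roundtrip_error(v, scale_merged))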
I don't quite understand:
"BTW, in the SmoothQuant implementation in TRT-LLM, we use separate (three) scales even if you choose per-tensor mode for QKV"
- Do you mean that in SmoothQuant w8a8 mode, the output quantization scales of the QKV linear layer are separate? I also tested this mode, and llama2-7b SmoothQuant w8a8 accuracy is also very bad.
- In SmoothQuant w8a8, is the KV cache also int8?
@Tracin
@brisker I mean that when using per-tensor weight quantization mode for SQ, QKV has three different weight scales. What is the quantization mode of this bad w8a8 accuracy? We do have a few bugs under some conditions. SQ and kv-cache are orthogonal.
@Tracin I just use the official SmoothQuant example, like this:
python hf_llama_convert.py -i /root/models/Llama-2-7b/ -o ./smooth_llama2_7b_alpha_0.5/sq0.5/ -sq 0.5 --tensor-parallelism 1 --storage-type fp16
python build.py --bin_model_dir /root/TensorRT-LLM/examples/llama/smooth_llama2_7b_alpha_0.5/sq0.5/1-gpu/ --use_gpt_attention_plugin float16 --remove_input_padding --enable_context_fmha --use_smooth_quant --per_token --per_channel --output_dir ./smooth_llama2_7b_alpha_0.5/smoothquant_w8a8_1_gpu/
python mmlu.py --hf_model_dir /root/models/Llama-2-7b/ --engine_dir /root/TensorRT-LLM/examples/llama/smooth_llama2_7b_alpha_0.5/smoothquant_w8a8_1_gpu/ --test_trt_llm
The commands above, which run the SmoothQuant MMLU test, give me 37.7 accuracy, while the FP16 accuracy is 45.9.
So far, SmoothQuant w8a8 and int8-kv-cache both seem to have bugs, with bad accuracy.
Have you confirmed any bugs?
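For reference, a minimal numpy sketch of what the -sq 0.5 smoothing factor means, following the SmoothQuant paper (per-channel factors migrate quantization difficulty from activations to weights); this illustrates the general idea, not TensorRT-LLM's exact implementation:

import numpy as np

rng = np.random.default_rng(0)
# Activations with a few large-magnitude channels, weights with a small range (illustrative).
x = rng.normal(0, 1, (16, 4096)) * rng.uniform(0.1, 8.0, 4096)
w = rng.normal(0, 0.02, (4096, 4096))
alpha = 0.5  # the value passed via `-sq 0.5`

# Per-input-channel smoothing factors: s_j = max|X_j|^alpha / max|W_j|^(1-alpha).
s = np.abs(x).max(axis=0) ** alpha / np.abs(w).max(axis=1) ** (1 - alpha)

x_smooth = x / s             # outlier channels are scaled down -> easier INT8 activations
w_smooth = w * s[:, None]    # the scale is folded into the weights
assert np.allclose(x @ w, x_smooth @ w_smooth)  # the FP result is mathematically unchanged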
Thanks for your reply, that makes it very clear. I will reproduce and fix it ASAP.
@Tracin I have tested int8-kv-cache and SmoothQuant w8a8 separately on Llama-1-7b, and both get good accuracy (close to the FP16 accuracy of about 35.5 on MMLU), just like what you reported before. So the problem seems specific to Llama-2-7b.
Just regard this as a cross-check.
@Tracin Is the bug fixed?
@Tracin
get bin_model: python hf_llama_convert.py -i /root/models/Llama-2-7b/ -o /root/TensorRT-LLM/examples/llama/llama2_7b_w8_int8_kv_cache/ --calibrate-kv-cache -t fp16
I use the bin files generated by the command above to build a weight-only quantized TRT engine, like this:
python build.py --bin_model_dir /root/TensorRT-LLM/examples/llama/llama2_7b_w8_int8_kv_cache/ --dtype float16 --use_gpt_attention_plugin float16 --use_gemm_plugin float16 --output_dir /root/TensorRT-LLM/examples/llama/llama2_7b_w8_int8_kv_cache/1-gpu-weight-only/ --use_weight_only
but the MMLU test accuracy is also bad.
However, if I directly build a weight-only quantized TRT engine like this:
python build.py --model_dir /root/models/Llama-2-7b/ --dtype float16 --remove_input_padding --use_gpt_attention_plugin float16 --enable_context_fmha --use_gemm_plugin float16 --use_weight_only --output_dir /root/TensorRT-LLM/examples/llama/llama2_7b_weight_only/
then the accuracy is good.
So this is very weird. The only difference between the two builds above seems to be
https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/examples/llama/build.py#L690 and https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/examples/llama/build.py#L708
@kaiyux @Shixiaowei02 I noticed that the Llama2-70B INT8-SmoothQuant accuracy drops only about 2%, as described here in the official docs.
Given the discussion of Llama2-7B in this issue, the accuracy drop is about 8% in my experiments, which is not reasonable. So for the experiments described there, which TensorRT-LLM version and experimental conditions are you using?
@brisker Hi, since you can reproduce good accuracy on LLaMA-7b and bad accuracy on LLaMA2-7b using SQ and INT8-KV respectively, it seems the different model parameters cause the difference, so there are no actual bugs, right? You can use AMMO to see if it produces better accuracy. As for the weight-only issue, I notice you have different build options between the bin_model and the HF model; could you make them aligned? And for the last question: LLaMA2-70B SQ has a 2% drop. Did you test on LLaMA2-70B? If so, what happens if --per_channel and --per_token are used?
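To make the --per_channel --per_token distinction concrete, here is a small PyTorch sketch of W8A8 scale granularity (illustrative only, not TensorRT-LLM's kernels): per-channel keeps one weight scale per output channel instead of one per tensor, and per-token computes one activation scale per token at runtime instead of one static calibrated scale.

import torch

x = torch.randn(8, 4096)       # activations: [tokens, in_features]
w = torch.randn(4096, 11008)   # weights:     [in_features, out_features]

# Weight scales: one per tensor vs. one per output channel.
w_scale_tensor  = w.abs().max() / 127
w_scale_channel = w.abs().amax(dim=0) / 127                 # shape [out_features]

# Activation scales: one static per-tensor scale vs. one dynamic scale per token.
x_scale_tensor = x.abs().max() / 127
x_scale_token  = x.abs().amax(dim=1, keepdim=True) / 127    # shape [tokens, 1]

# Per-token / per-channel W8A8: quantize, GEMM, then dequantize with the
# outer product of the two scale vectors.
x_q = torch.clamp(torch.round(x / x_scale_token), -128, 127)
w_q = torch.clamp(torch.round(w / w_scale_channel), -128, 127)
y = (x_q @ w_q) * x_scale_token * w_scale_channel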
@Tracin How do I build a TensorRT engine using the files created by AMMO w8a8 SmoothQuant? I cannot find any docs.
@Tracin For the weight-only issue, you mentioned making the build options aligned; which options are you referring to?
@Tracin
Many LLM quantization papers (for example, this paper) have stated that LLama2-7b w8a8 SmoothQuant accuracy is close to FP16 accuracy on MMLU (and I have also done some experiments in my own code, where the accuracy is also good),
so I do not think the "there are no actual bugs" conclusion is convincing. Besides, have you tried to reproduce the accuracy drop yourself? @byshiue
I mean that if you want to test accuracy and compare to papers, please use --per_channel --per_token.
@Tracin You can check the comments I have already written in this issue; I have already used --per_channel --per_token.
@brisker Thanks for your consistent efforts. I did not reproduce the bad accuracy with INT8-kv or weight-only; they both produce a score of 0.460.
As for the SQ problem, there is a bug. You can fix it manually, and I will push an MR later.
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/quantization/quantize.py#L165
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/quantization/quantize.py#L219
Please add an argument eps=config.norm_epsilon, at both lines.
Now SQ on LLAMA2-7b can produce score of 0.454.
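For anyone else hitting this, a minimal sketch of why the missing eps argument can matter (the assumption here is that without it, the quantized RMSNorm falls back to a default epsilon that differs from the model's rms_norm_eps; Llama-2 uses 1e-5 while Llama-1 uses 1e-6, which would also explain why only Llama-2 degrades):

import torch

def rms_norm(x, weight, eps):
    # Standard RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

x = torch.randn(4, 4096) * 1e-2   # small-magnitude hidden states make eps noticeable
w = torch.ones(4096)

y_right = rms_norm(x, w, eps=1e-5)      # Llama-2's config.norm_epsilon
y_wrong = rms_norm(x, w, eps=1e-6)      # a mismatched default (assumed)
print((y_right - y_wrong).abs().max())  # a small per-layer drift that compounds over 32 layers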
@Tracin Thanks for the effort. You mentioned that the bad accuracy with int8-kv was not reproduced. Can you share your TensorRT-LLM version and the commands you ran?
Sure, please use the latest main branch.
python convert_checkpoint.py --model_dir llama-v2-7b-hf/ -t fp16 --output_dir llama2-7b-int8kv --int8_kv_cache
trtllm-build --checkpoint_dir llama2-7b-int8kv/ --max_batch_size 32 --max_input_len 2048 --max_output_len 512 --gpt_attention_plugin float16 --output_dir llama2-7b-int8kv-engine --strongly_typed
@Tracin
- My TensorRT-LLM version is 0.7.1, and I applied the modifications you mentioned above, but I still get 37.6 for w8a8 SmoothQuant accuracy on MMLU. Are there some other bugs in the 0.7.1 version of TensorRT-LLM? Is the difference between 0.7.1 and the main branch very big?
- Besides, can you also provide the smoothquant commands you have run?