TensorRT-LLM
Add weight-only quantization for T5 models
Summary
Add weight-only quantization for T5.
I've added this to the path that loads from binary weights. I don't think the HF weight loading path currently works, so I have not added this functionality there yet, but please let me know if I've misunderstood.
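For intuition, weight-only quantization boils down to storing each GEMM weight as low-bit integers plus a per-output-channel scale, while activations stay in fp16. The sketch below is a plain-PyTorch illustration of that idea, not the code in this PR; the actual loader defers to TensorRT-LLM's own preprocessing so the layout matches the weight-only GEMM plugin.

```python
import torch

def symmetric_quantize_per_channel(weight: torch.Tensor, bits: int = 8):
    """Illustrative per-channel symmetric quantization of a [k, n] GEMM weight."""
    qmax = 2 ** (bits - 1) - 1                                   # 127 for int8, 7 for int4
    scales = weight.abs().amax(dim=0).clamp(min=1e-8) / qmax     # one scale per output channel
    # int4 values are left unpacked in an int8 container here; the real plugin packs two per byte.
    q_weight = torch.clamp(torch.round(weight / scales), -qmax - 1, qmax).to(torch.int8)
    return q_weight, scales.to(weight.dtype)
```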
Test Plan
Built the engine with t5-small using the following commands:
python t5/convert.py -i t5-small -o $FT_WEIGHT_DIR --weight_data_type float32 --inference_tensor_para_size 1
and then
python build.py --model_type t5 \
--weight_dir ${FT_WEIGHT_DIR}/tp1 \
-o $TRT_ENGINE_DIR \
--engine_name t5 \
--use_bert_attention_plugin \
--use_gpt_attention_plugin \
--use_gemm_plugin \
--dtype float16 \
--max_beam_width 4 \
--max_batch_size 32 \
--max_output_len 512 \
--use_weight_only \
--weight_only_precision int8
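For reference, the two new flags are intended to map onto a QuantMode roughly as follows (a sketch assuming the tensorrt_llm.quantization.QuantMode helper; the exact wiring inside build.py may differ):

```python
from tensorrt_llm.quantization import QuantMode

def quant_mode_from_flags(use_weight_only: bool, weight_only_precision: str) -> QuantMode:
    # No quantization requested: plain fp16/fp32 weights.
    if not use_weight_only:
        return QuantMode(0)
    # "int4" selects 4-bit weights; anything else falls back to int8 weight-only.
    return QuantMode.use_weight_only(use_int4_weights=(weight_only_precision == "int4"))
```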
We then measured the performance using:
python run.py --engine_dir $TRT_ENGINE_DIR --engine_name t5-small --model_name t5-small --max_new_token=64 --num_beams=1 --compare_hf_fp32
Model Sizes
| Precision | Encoder | Decoder |
| --- | --- | --- |
| fp16 | 69 MB | 112 MB |
| Weight-only int8 | 51 MB | 88 MB |
| Weight-only int4 | 42 MB | 76 MB |
Benchmark Results
| Configuration | TRT-LLM E2E time | Literal match rate vs. HF FP32 |
| --- | --- | --- |
| No quantization (fp16) | 233.34 ms | 0.926 |
| Weight-only int8 (fp16 activations) | 215.11 ms | 0.784 |
| Weight-only int4 (fp16 activations) | 203.82 ms | 0.381 |
Follow-ups
- Fix the FP16 inference issue with T5 models (https://github.com/huggingface/transformers/issues/20287). I believe we need fp16 for weight-only quantization to work, so this fix will be needed for it to be useful in production.
- Consider adding a flag to quantize only the decoder. This greatly improves match quality, so it would give more granular control over the quality vs. latency trade-off (see the sketch after this list).
- Support remaining quantization schemes (SmoothQuant, AWQ, GPTQ).
- Add this for other enc_dec models beyond T5.
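A rough sketch of what follow-up 2 could look like; the --quantize_decoder_only flag and the component argument are hypothetical and not part of this PR:

```python
from tensorrt_llm.quantization import QuantMode

def quant_mode_for_component(component: str, use_weight_only: bool,
                             weight_only_precision: str,
                             quantize_decoder_only: bool) -> QuantMode:
    # Hypothetical: keep the encoder unquantized when only the decoder should be quantized.
    if not use_weight_only or (quantize_decoder_only and component == "encoder"):
        return QuantMode(0)
    return QuantMode.use_weight_only(use_int4_weights=(weight_only_precision == "int4"))
```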
cc @symphonylyh: could you please take a look at this PR for the first part of T5 quantization?
Making a change in tensorrt_llm/models/enc_dec/model.py like the one in the picture may solve the m=0 error.
Thank you so much @Eddie-Wang1120! This indeed fixed the build issue I was having, and the output now looks correct. This has unblocked me for the next steps.
@eycheung thanks for the contribution, and thanks @Eddie-Wang1120 for the advice. (I lost track of what the m=0 error was, since I think the author has since modified the PR description.)
@eycheung for Follow-up 1, internally we're working on an FP16 fix for the T5 family, where we need two overflow clamps around the layernorm/rmsnorm. For example, right now the match rate is > 0.9 for T5-small, but when you move to larger models like T5-3B, the overflow issue becomes significant without clamping.
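For reference, the overflow guard that Hugging Face applies around T5's feed-forward block looks roughly like the sketch below (plain PyTorch, the general pattern only, not the internal TensorRT-LLM change):

```python
import torch

def clamp_fp16_overflow(hidden_states: torch.Tensor) -> torch.Tensor:
    # Keep fp16 activations away from +/-inf, mirroring the clamp in
    # transformers' modeling_t5.py; a no-op for fp32/bf16 tensors.
    if hidden_states.dtype == torch.float16:
        clamp_value = torch.finfo(torch.float16).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states
```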
As a next step, we'll take your current PR internally, run it through CI, and acknowledge your contribution when it's released. I agree with your idea to keep this one self-contained for quantization, and to address Follow-ups 2 to 4 in separate PRs.
Thanks for your brilliant work @eycheung, and thanks for your time @symphonylyh. I noticed @symphonylyh lost track of the m=0 bug. I happened to write a detailed explanation of it in https://github.com/NVIDIA/TensorRT-LLM/pull/992; you can read it for more information if interested.
Great, thank you @symphonylyh! And I appreciate the update and the work on follow-up 1.
And thank you @Eddie-Wang1120 for the help on this. To summarize the issue: before @Eddie-Wang1120's fix, the build would hang when enabling the weight_only_gemm_plugin during the profiling step:
[TensorRT-LLM][WARNING] Cannot profile configuration 59 (for m=0, n=3840, k=1280), reason: "[TensorRT-LLm Error][fpA_intB Runner] Failed to run cutlass fpA_intB gemm. Error: Error Internal". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 60 (for m=0, n=3840, k=1280), reason: "[TensorRT-LLm Error][fpA_intB Runner] Failed to run cutlass fpA_intB gemm. Error: Error Internal". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 61 (for m=0, n=3840, k=1280), reason: "Temp assertion: k must be multiple of threadblockK". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 62 (for m=0, n=3840, k=1280), reason: "Temp assertion: k must be multiple of threadblockK". Skipped
[TensorRT-LLM][WARNING] Have not found any valid GEMM config for shape (m=0, n=3840, k=1280). Will try to use default or fail at runtime
I'm trying to run this PR, but I'm getting an assertion error when running the build.py script, even without weight-only quantization (with FP32 too). I couldn't figure out the reason. The development GPU is an RTX 3090. Do you have any idea why?
Traceback (most recent call last):
File "/root/tensorrt-llm/examples/enc_dec/build.py", line 635, in <module>
run_build(component='decoder')
File "/root/tensorrt-llm/examples/enc_dec/build.py", line 626, in run_build
build(0, args)
File "/root/tensorrt-llm/examples/enc_dec/build.py", line 581, in build
engine = build_rank_engine(builder, builder_config, engine_name,
File "/root/tensorrt-llm/examples/enc_dec/build.py", line 512, in build_rank_engine
tllm_model(*inputs)
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/module.py", line 40, in __call__
output = self.forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/models/enc_dec/model.py", line 960, in forward
hidden_states = decoder_layer(
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/module.py", line 40, in __call__
output = self.forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/models/enc_dec/model.py", line 400, in forward
attention_output = self.self_attention(
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/module.py", line 40, in __call__
output = self.forward(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/layers/attention.py", line 741, in forward
context, past_key_value = gpt_attention(
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/graph_rewriting.py", line 561, in wrapper
outs = f(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/tensorrt_llm/functional.py", line 3778, in gpt_attention
assert kv_cache_block_pointers is not None, "Paged kv cache is enabled, the kv_cache_block_pointers tensor shall not
be None"
AssertionError: Paged kv cache is enabled, the kv_cache_block_pointers tensor shall not be None
@monatis Can you confirm what version of tensorrt_llm you are building this with? And did you rebuild the library from source before running this script? I am just trying to see if your version of trt-llm is newer than what I tested with, since I was not aware there was paged KV cache support for encoder-decoder when I wrote this.
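If it helps, the installed version can be printed with (assuming a pip-installed wheel):
python -c "import tensorrt_llm; print(tensorrt_llm.__version__)"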
@eycheung It was the pre-released 0.9.0. Now I'll try to extract the EncDec module into my own scripts and modify the examples to import it instead, bringing your PR and the latest commit on main together.
Great, thank you @monatis! Sorry I could not help with this specific issue, though.