TensorRT-LLM
Fix enc_dec bug and Make several improvements to whisper
Thanks to the brilliant work from the NVIDIA team! I made some changes to TensorRT-LLM and hope to get some advice!
Pull Request Intro
This Pull Request includes several points:
- fix a bug in the enc_dec model that leads to a build error when the model has cross attention and uses the weight_only_gemm_plugin at the same time.
- ban the LayerNorm plugin, which otherwise causes a severe memory usage increase for whisper fp16 inference (roughly 16000MiB instead of 8000MiB).
- add int4 weight-only support to whisper
- give a base implementation for whisper int8_kv_cache (half finished due to an internal error)
What the bug is and what I did
Bug intro
The bug can be reproduced in the previous version by adding the weight_only_gemm_plugin to the whisper decoder model. The expected behaviour is for the build to pass. However, during the profiling phase of the build, errors like the ones below show up in the log and the build ends up failing.
[TensorRT-LLM][WARNING] Cannot profile configuration 59 (for m=0, n=3840, k=1280), reason: "[TensorRT-LLm Error][fpA_intB Runner] Failed to run cutlass fpA_intB gemm. Error: Error Internal". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 60 (for m=0, n=3840, k=1280), reason: "[TensorRT-LLm Error][fpA_intB Runner] Failed to run cutlass fpA_intB gemm. Error: Error Internal". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 61 (for m=0, n=3840, k=1280), reason: "Temp assertion: k must be multiple of threadblockK". Skipped
[TensorRT-LLM][WARNING] Cannot profile configuration 62 (for m=0, n=3840, k=1280), reason: "Temp assertion: k must be multiple of threadblockK". Skipped
[TensorRT-LLM][WARNING] Have not found any valid GEMM config for shape (m=0, n=3840, k=1280). Will try to use default or fail at runtime
How I solved it
After rebuilding the whisper decoder model (which inherits from the enc_dec DecoderModel) layer by layer, I found that the error only happens when the model has cross attention. More suspiciously, when checking the prepare_inputs function in DecoderModel, a variable called encoder_input_len_range caught my eye: it is a dim range used by several inputs specific to cross attention, and its minimum is 0, which exactly explains the m=0 shapes in the build log.
```python
encoder_input_len_range = [
    0, (max_encoder_input_len + 1) // 2, max_encoder_input_len
]
```
In my opinion, the minimum value of encoder_input_len_range does not have to be 0, because it is not like the kv-cache, which needs to be concatenated. After I changed it to 1, the build passed successfully and the results remained correct. Now all enc_dec models can use the weight_only_gemm_plugin and enjoy the performance improvements freely.
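For reference, a minimal sketch of the adjusted range in prepare_inputs (same variable names as above):

```python
# Keep the minimum at 1 instead of 0 so the optimization profile never has to
# cover an m=0 GEMM shape for the cross-attention inputs.
encoder_input_len_range = [
    1, (max_encoder_input_len + 1) // 2, max_encoder_input_len
]
```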
About LayerNorm plugin
Banning the LayerNorm plugin has always been a priority, since it is going to be deprecated. The main reason it was still retained in the previous version is that simply banning it led to a build failure. In this version, banning it no longer raises any errors and brings multiple benefits. Most notably, the memory usage of whisper fp16 inference decreases from 16030MiB to around 8000MiB, which means whisper can be run with TensorRT-LLM on more devices.
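For context, a minimal sketch of what "banning" means here; the exact plugin_config calls in the whisper build script are an assumption on my side, following the API style quoted later in this thread:

```python
# Hedged sketch: "banning" the LayerNorm plugin just means not registering it,
# so normalization falls back to native TensorRT layers. Previously the build
# script registered it roughly like:
#   network.plugin_config.set_layernorm_plugin(dtype=args.dtype)
# With this change that call is simply dropped; only the plugins that are still
# wanted (e.g. the GEMM / weight-only GEMM plugin) are registered.
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
```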
About int8_kv_cache
It's a pity that int8_kv_cache for the whisper model is still not finished. The build seems to go through correctly, but an internal error occurs at inference time. After trying every approach I could think of, the error still persists. I created an issue for this bug (https://github.com/NVIDIA/TensorRT-LLM/issues/993) with detailed information. If anyone is interested and has an idea, please let me know; I sincerely hope this error can be solved soon. Thank you all in advance.
Performance
| Metric | float16 (with layernorm plugin) | float16 | int8 weight-only | int4 weight-only |
|---|---|---|---|---|
| GPU memory usage | 16030MiB | 8186MiB | 6717MiB | 6036MiB |
| RTF | 0.1962 | 0.0542 | 0.0488 | 0.0473 |
| processing time | 94.397s | 26.066s | 23.492s | 22.741s |
| batch_size | 4 | 4 | 4 | 4 |
| num_beams | 1 | 1 | 1 | 1 |
| WER | 2.99 | 2.99 | 2.82 | 6.33 |
Environment
- intel i5 13500
- nvidia 4060ti 16G
- Tensorrt-LLM commitID b57221b764bc579cbb2490154916a871f620e2c4
- container nvidia-docker run --entrypoint /bin/bash -it nvidia/cuda:12.1.0-devel-ubuntu22.04
This is really nice work! Many thanks to you @Eddie-Wang1120. I would import this into internal gitlab and hopefully it could be done this week.
Thanks a lot!
Thanks for the awesome contributions from you two!
Adding some of my minor observations relevant to this:
- Layer norm issue: +1, I observed similar behavior.
- `bert_attention_plugin` weird behavior:
  - On A30 GPU, setting `network.plugin_config.set_bert_attention_plugin(dtype=args.use_bert_attention_plugin)` negatively impacts performance.
  - Following the `bert_attention_plugin` docs, I added `network.plugin_config.set_context_fmha(ContextFMHAType.enabled)` to the builder config, which improved inference speed but increased memory usage.
  - Interestingly, simply disabling `bert_attention_plugin` achieves similar speed with lower memory usage.
  - On T4 GPU, unlike A30, disabling `bert_attention_plugin` outperforms using both `bert_attention_plugin` and `context_fmha`.
Thanks for your advice! @shashikg Following your observations, I disabled bert_attention_plugin and got some results:
with bert_attention_plugin
| Metric | float16 | int8 weight-only | int4 weight-only |
|---|---|---|---|
| GPU memory usage | 8186MiB | 6717MiB | 6036MiB |
| RTF | 0.0542 | 0.0488 | 0.0473 |
| processing time | 26.066s | 23.492s | 22.741s |
| batch_size | 4 | 4 | 4 |
| num_beams | 1 | 1 | 1 |
| WER | 2.99 | 2.82 | 6.33 |
disable bert_attention_plugin
| Metric | float16 | int8 weight-only | int4 weight-only |
|---|---|---|---|
| GPU memory usage | 6065MiB | 5696MiB | 5024MiB |
| RTF | 0.0464 | 0.0489 | 0.0465 |
| processing time | 22.314s | 23.530s | 22.379s |
| batch_size | 4 | 4 | 4 |
| num_beams | 1 | 1 | 1 |
| WER | 2.99 | 2.82 | 6.33 |
The results show that disabling bert_attention_plugin indeed decreases memory usage, and may improve inference speed in some situations. Maybe we should consider using this plugin cautiously.
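For reproducibility, here is a hedged sketch of the two configurations compared above, reusing the plugin_config calls quoted earlier in this thread (the `network` and `args` names follow the build-script style assumed there):

```python
# Configuration "with bert_attention_plugin" (paired with fused context attention):
network.plugin_config.set_bert_attention_plugin(dtype=args.use_bert_attention_plugin)
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)

# Configuration "disable bert_attention_plugin": simply skip the calls above, so the
# encoder falls back to TensorRT's native (plugin-free) attention implementation.
```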
@Eddie-Wang1120 great to know about the encoder_input_len_range issue when used together with the weight-only GEMM plugin. I agree with your fix; the min value doesn't need to be 0 in all cases.
@Eddie-Wang1120 @shashikg general guidance on layernorm and bert plugin usage:
- For LayerNorm/RMSNorm plugin, it's in deprecation mode, so it's recommended to do without these normalization plugins
- For BERT plugin:
  - First, it should always be used together with `--enable_context_fmha`; otherwise the comparison is not fair, because it falls back to the unfused multi-head attention implementation.
  - Second, regarding w/ and w/o BERT plugin, we have done some investigation and have 3 takeaways:
- Peak memory usage wise, w/o BERT plugin is indeed better than w/ BERT plugin. If peak memory is a restriction, consider using the w/o BERT plugin path
- Performance wise, on BERT example itself, w/o and w/ plugin paths are on par based on our benchmark. However, we mainly tested on newer GPUs such as Ampere and Hopper. It's possible on older ones like T4 you observed a different trend. In that case, it's recommended to try both on your specific GPU and decide
- Lastly, from a practical perspective, w/o BERT plugin path has a limitation on padding removal -- that is, when you have ragged input, e.g., batch size = 2, text1 is length 10, text2 is length 100, w/ BERT plugin path can do padding removal by effectively doing a computation of length 10+100=110 text (the BERT example currently doesn't demonstrate this, which I plan to add and clarify this point), while the w/o BERT plugin path can only do computation on the padded one, so equivalently 100+100=200 text. This could make a big difference in real deployment. If this is a concern, this last point would become a deciding factor to favor the w/ plugin path.
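As a toy illustration of the padding-removal arithmetic in the last point (hypothetical lengths, not a TensorRT-LLM API):

```python
# Ragged batch: two sequences of length 10 and 100.
seq_lens = [10, 100]

# w/ BERT plugin (padding removal): compute only the real tokens.
tokens_with_plugin = sum(seq_lens)                     # 110

# w/o BERT plugin: every sequence is padded to the longest one.
tokens_without_plugin = len(seq_lens) * max(seq_lens)  # 200
```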
Thank you so much @symphonylyh for the guidelines!
Lastly, from a practical perspective, w/o BERT plugin path has a limitation on padding removal -- that is, when you have ragged input, e.g., batch size = 2, text1 is length 10, text2 is length 100, w/ BERT plugin path can do padding removal by effectively doing a computation of length 10+100=110 text (the BERT example currently doesn't demonstrate this, which I plan to add and clarify this point), while the w/o BERT plugin path can only do computation on the padded one, so equivalently 100+100=200 text. This could make a big difference in real deployment. If this is a concern, this last point would become a deciding factor to favor the w/ plugin path.
I see... Based on this, I think it now makes sense why performance on the Whisper model does not improve w/ BERT plugin (because I was running the inference on fixed 30-second inputs). The whisper model is trained on fixed 30-second audio, and during inference it also expects to receive 30-second audio. Even if an audio clip is shorter than 30 seconds, if we run whisper's encoder on it without padding the input to 30 seconds, whisper's decoder falls more frequently into generating hallucinated or repeated text. So basically the inputs to whisper's encoder will always be of the same length.
@shashikg We could actually remove the 30s padding restriction of the encoder, see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py#L15. It would save cross kv cache VRAM usage as well. However, there is a bug now if we set the conv subsampling layers in the encoder with a dynamic seq_len dim.
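For readers who don't want to open the link, a rough sketch of what that monkey patch does (my assumption based on the openai-whisper AudioEncoder layout; the actual patch in icefall is authoritative):

```python
import torch
import torch.nn.functional as F

def encoder_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Whisper AudioEncoder.forward without the fixed 30 s (1500-frame) assertion."""
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)
    # Slice the positional embedding to the actual sequence length instead of
    # asserting that the input is exactly 30 s of audio.
    x = (x + self.positional_embedding[: x.shape[1]]).to(x.dtype)
    for block in self.blocks:
        x = block(x)
    return self.ln_post(x)
```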
Thanks for the guidelines! @symphonylyh
@shashikg We could actually remove the 30s padding restriction of the encoder, see https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py#L15. It would save cross kv cache VRAM usage as well.
Hey yes, I agree and most probably this should improve the inference time. I have tested dynamic seq_len in my project "WhisperS2T" (https://github.com/shashikg/WhisperS2T/blob/main/whisper_s2t/backends/__init__.py#L35) with the CTranslate2 backend, but currently it's in an experimental phase (so it can break and is thus not included in the docs).
So my concern is not whether we can run it or not. If we infer with dynamic seq_len, what I observed is that whisper's decoder makes more errors in the generated text (mostly non-stopping repeated text tokens). There are definitely various heuristics we can use to work around this, but after adding those heuristics the inference time will increase. Moreover, these non-stopping repeated tokens specifically also increase the generation time significantly. This issue can definitely be avoided by fine-tuning the model on dynamic seq_len, which OpenAI didn't do for some reason.
However, there is a bug now if we set the conv subsampling layers in the encoder with a dynamic seq_len dim.
I am curious what the exact issue is; normally the patch should work. I have tried out a similar thing in the past. One issue I can think of is the detect_language function, check this: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L51 -- this check will create an issue if you use the detect_language function with dynamic seq_len.
So my concern is not whether we can run it or not. If we infer with dynamic seq_len, what I observed is that whisper's decoder makes more errors in the generated text (mostly non-stopping repeated text tokens). There are definitely various heuristics we can use to work around this, but after adding those heuristics the inference time will increase. Moreover, these non-stopping repeated tokens specifically also increase the generation time significantly. This issue can definitely be avoided by fine-tuning the model on dynamic seq_len, which OpenAI didn't do for some reason.
Yes, one of the heuristics is to pad 50 frames at the end. https://github.com/k2-fsa/sherpa-onnx/pull/471
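A minimal sketch of that heuristic, assuming mel features shaped (batch, n_mels, T) as in openai-whisper (the shapes here are my assumption, not taken from the linked PR):

```python
import torch
import torch.nn.functional as F

def pad_tail(mel: torch.Tensor, pad_frames: int = 50) -> torch.Tensor:
    """Append `pad_frames` blank frames at the end of the time axis."""
    return F.pad(mel, (0, pad_frames))  # pads only the last (time) dimension
```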
I am curious what the exact issue is; normally the patch should work. I have tried out a similar thing in the past. One issue I can think of is the detect_language function, check this: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L51 -- this check will create an issue if you use the detect_language function with dynamic seq_len.
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/functional.py#L2813-L2814 This view operation has an issue. I think it should be a small fix to handle it.
@yuekaizhang I have less background on the Whisper discussion here, but do you mean the current functional.py::conv2d() cannot handle dynamic axes due to the output.view(concat([output.size(1), output.size(2), output.size(3)])) call?
If I understand correctly, this call is doing a squeeze call to remove the 1st dimension, as symmetric to the unsqueeze(input) call before. In that case, do you think select(output, dim=0, index=0) can do the same and meanwhile support dynamic axis?
Update: please use a more general squeeze implementation for now; add this to functional.py:
```python
def squeeze(input: Tensor, dim: Union[int, Sequence[int]] = None):
    # Default: squeeze every dimension of size 1.
    if dim is None:
        dim = list(range(input.ndim()))
    if isinstance(dim, int):
        dim = (dim, )
    # Keep every dimension except the size-1 dims listed in `dim`.
    new_shape = []
    for i, s in enumerate(input.shape):
        if s == 1 and i in dim:
            continue
        new_shape.append(shape(input, i))
    input = input.view(concat(new_shape))
    return input
```
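If I read the fix correctly, a hedged usage sketch for the conv2d case discussed above (the surrounding conv2d code is assumed, not shown here) would be:

```python
# Inside functional.py::conv2d(), instead of the dynamic-shape-unfriendly reshape
#   output = output.view(concat([output.size(1), output.size(2), output.size(3)]))
# squeeze away the leading dim that was added by the earlier unsqueeze(input) call:
output = squeeze(output, 0)
```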
@yuekaizhang I have less background on the Whisper discussion here, but do you mean the current functional.py::conv2d() cannot handle dynamic axes due to the output.view(concat([output.size(1), output.size(2), output.size(3)])) call?
If I understand correctly, this call is doing a squeeze call to remove the 1st dimension, as symmetric to the unsqueeze(input) call before. In that case, do you think select(output, dim=0, index=0) can do the same and meanwhile support dynamic axis?
Thanks, I will try your suggestion and give you feedback. @shashikg @symphonylyh
Added a data point using an A16 GPU, batch_size 4, num_beams 1.
| Metric | FP16 | Weight-only-quant int8 |
|---|---|---|
| Decoding time | 35 secs | 33 secs |
| Word Error Rate | 2.48% | 2.48% |
| VRAM | 5.5 GB | 4 GB |