tensorrtllm_backend

Whisper one-shot enc+dec path treats mel frames (3000) as “encoder length”, causing a length assertion failure against the 1500 limit and internal broadcast shape errors

Open · YuBeomGon opened this issue · 1 comment

Environment

Triton Server: nvcr.io/nvidia/tritonserver:25.08-trtllm-python-py3

TRT-LLM version inside server: 1.1.0rc1

Engines built with TRT-LLM: 1.1.0rc1 (also tried 1.1.0rc5 in devel)

Model: Whisper large-v3-turbo (TRT-LLM encoder+decoder)

Expected

For Whisper, the input mel length is 3000 frames (30 s of 16 kHz audio at hop=160), but the encoder's internal sequence length is 1500 because of its stride-2 convolution.
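The 3000 → 1500 relationship can be checked with the standard 1-D convolution output-length formula. This is a minimal sketch assuming the usual Whisper front end (16 kHz audio, hop length 160, two Conv1d layers with kernel_size=3 and padding=1, strides 1 and 2); it does not read anything from a TRT-LLM engine.

```python
# Sketch: how Whisper's 3000 mel frames become 1500 encoder positions.
# Assumptions: 16 kHz audio, hop length 160, 30 s windows, and two Conv1d
# layers with kernel_size=3, padding=1 (stride 1, then stride 2).

SAMPLE_RATE = 16_000   # Hz
HOP_LENGTH = 160       # samples between consecutive mel frames
CHUNK_SECONDS = 30     # Whisper pads/trims audio to 30 s windows


def conv1d_out_len(length: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel) // stride + 1


def whisper_lengths(seconds: int = CHUNK_SECONDS) -> tuple:
    """Return (mel_frames, encoder_length) for a window of `seconds`."""
    mel_frames = seconds * SAMPLE_RATE // HOP_LENGTH
    after_conv1 = conv1d_out_len(mel_frames, kernel=3, stride=1, padding=1)
    after_conv2 = conv1d_out_len(after_conv1, kernel=3, stride=2, padding=1)
    return mel_frames, after_conv2


mel_frames, encoder_len = whisper_lengths()
print(mel_frames, encoder_len)  # 3000 1500
```

With kernel 3, padding 1 and stride 2, the second convolution's output length reduces to ceil(L / 2), which is why the encoder sees exactly half of the mel frames.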

The one-shot enc+dec path should accept encoder_input_features with shape [B,3000,128] and input_lengths=1500 and pass the encoder output to the decoder without errors, as the official Session-based flow does.

Actual

Case A (encoder max_seq_len=1500, default):

[TensorRT-LLM][ERROR] Assertion failed: Encoder length (3000) exceeds maximum encoder input length (1500).

Case B (encoder forced to max_seq_len=3000):

Invalid input shape ... ELEMENTWISE_SUM: Broadcast incompatible: 6000 != 3000 ...

(likely an internal shape mismatch between the position embedding and the conv outputs)
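The ELEMENTWISE_SUM error is a broadcast-compatibility failure: an elementwise sum only works if, dimension by dimension, the two sizes are equal or one of them is 1. The sketch below uses NumPy-style right-aligned broadcasting as an approximation of that rule; the hidden size and the assignment of 6000 and 3000 to particular tensors are illustrative assumptions, since the log does not say which tensor carries which length.

```python
# Sketch of the broadcast rule an elementwise sum enforces (NumPy-style):
# align shapes from the right; each dimension pair must match or contain a 1.
# Shapes below are illustrative, not read from the actual encoder engine.

def broadcast_compatible(a: tuple, b: tuple) -> bool:
    """True if shapes a and b can be broadcast together."""
    for x, y in zip(reversed(a), reversed(b)):
        if x != y and x != 1 and y != 1:
            return False
    return True


HIDDEN = 1280  # illustrative hidden size

# Lengths from the Case B log: one operand has 6000 positions, the other 3000.
print(broadcast_compatible((6000, HIDDEN), (3000, HIDDEN)))  # False
# Conv output and position embedding both at 1500 positions: compatible.
print(broadcast_compatible((1500, HIDDEN), (1500, HIDDEN)))  # True
```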

Repro Steps

trtllm-build \
  --checkpoint_dir /ws/tllm_ckpt_whisper/encoder \
  --output_dir /ws/engines/whisper_large_v3_turbo/encoder \
  --max_batch_size 16 \
  --max_input_len 1500 \
  --bert_attention_plugin float16

trtllm-build \
  --checkpoint_dir /ws/tllm_ckpt_whisper/decoder \
  --output_dir /ws/engines/whisper_large_v3_turbo/decoder \
  --max_batch_size 16 \
  --max_encoder_input_len 1500 \
  --max_input_len 224 \
  --max_seq_len 448 \
  --max_beam_width 5 \
  --gather_generation_logits \
  --logits_dtype float16 \
  --gpt_attention_plugin float16 \
  --kv_cache_type paged \
  --gemm_plugin float16

1. Build the encoder with --max_input_len 1500 (or let it be deduced) and the decoder with --max_encoder_input_len 1500.

2. Send the following to the tensorrtllm model in Triton:

   encoder_input_features: [B,3000,128] FP16

   input_lengths: 1500

3. Observe the assertion error shown in Case A.

4. Rebuild the encoder with --max_seq_len 3000 to bypass the assertion, and observe the internal broadcast error shown in Case B.
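For reference, the request described in the repro steps can be summarized as plain data. This is a hypothetical sketch: the tensor names come from the report, but the helper, the dict layout, and the shape and datatype of input_lengths ([B,1], INT32) are assumptions, not the tritonclient API or the backend's documented schema.

```python
# Hypothetical summary of the request tensors from the repro steps.
# Tensor names are from the report; the input_lengths shape/datatype and
# the dict layout are assumptions for illustration only.

MEL_BINS = 128      # large-v3 / large-v3-turbo use 128 mel bins
MEL_FRAMES = 3000   # 30 s window
CONV_STRIDE = 2     # downsampling applied by the Whisper encoder


def describe_request(batch_size: int) -> dict:
    """Return name -> {shape, datatype, value} for each input tensor."""
    return {
        "encoder_input_features": {
            "shape": [batch_size, MEL_FRAMES, MEL_BINS],
            "datatype": "FP16",
        },
        "input_lengths": {
            "shape": [batch_size, 1],                # assumed layout
            "datatype": "INT32",                     # assumed dtype
            "value": MEL_FRAMES // CONV_STRIDE,      # 1500, as in the report
        },
    }


for name, spec in describe_request(batch_size=2).items():
    print(name, spec)
```

The point of writing it out is that the two tensors disagree by design: the features carry 3000 frames while the length carries the post-convolution 1500.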

Notes

The Session-based flow (run_whisper.py) works: the encoder accepts 3000 mel frames and passes encoder_output_lengths=1500 to the decoder without issues.

The one-shot path appears to use the mel length (3000) as the “encoder length” when checking it against the maximum, instead of the downsampled length (1500).
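The difference between the observed check and the expected one can be sketched in a few lines. This is illustrative only, not TRT-LLM source code: the function names are hypothetical, and the 1500 limit is taken from the Case A build.

```python
# Illustrative contrast between the length check the one-shot path appears
# to perform and the one the Session-based flow implies. Hypothetical names;
# not TRT-LLM code. The 1500 limit matches the Case A build.

MAX_ENCODER_INPUT_LEN = 1500
CONV_STRIDE = 2


def encoder_len_as_observed(mel_frames: int) -> int:
    # One-shot path (as reported): mel length used directly.
    return mel_frames


def encoder_len_as_expected(mel_frames: int) -> int:
    # Length after the stride-2 conv (kernel 3, padding 1) = ceil(L / 2).
    return (mel_frames + CONV_STRIDE - 1) // CONV_STRIDE


def check(encoder_len: int) -> None:
    if encoder_len > MAX_ENCODER_INPUT_LEN:
        raise ValueError(
            f"Encoder length ({encoder_len}) exceeds maximum encoder input "
            f"length ({MAX_ENCODER_INPUT_LEN})."
        )


check(encoder_len_as_expected(3000))  # passes: 1500 <= 1500
try:
    check(encoder_len_as_observed(3000))
except ValueError as err:
    print(err)  # Encoder length (3000) exceeds maximum encoder input length (1500).
```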

Questions

Is this a known limitation for the one-shot enc+dec path with Whisper?

Can the length validation and internal shape setup be updated to account for encoder stride (×2 downsampling) so that the one-shot path matches the Session-based flow?

Are there recommended flags or configuration settings that make the one-shot path work without downsampling the mel features externally?

YuBeomGon, Sep 29 '25

Try this example; it works just fine:

https://github.com/triton-inference-server/tensorrtllm_backend/blob/main/docs/whisper.md

protonicage, Oct 02 '25