TensorRT-LLM
add `chunk_length` parameter to Whisper
distil-whisper models perform best with chunk sizes shorter than the 30s used by the original Whisper models, so this PR introduces the option to build the engine with a different chunk length.
Summary of the changes in this PR:
- Whisper encoder now supports changing `chunk_size` (see the sketch after this list)
- Example has been updated to support `remove_input_padding` in the decoder
- `conv1d` now supports input with more than 1 dynamic shape
- Whisper decoder should now support inflight batching when built with `paged_kv_cache` using the executor, although there is no clear way to feed the encoder input and the prompt to the `tensorrt_llm.bindings.Request` class, as it only accepts lists of tokens for all inputs and the encoder output is a float tensor
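As a quick illustration of how the chunk length relates to the encoder context size (an illustrative helper, not code from this PR): Whisper extracts 100 log-mel frames per second and the encoder convolutions downsample by a factor of 2, so the default 30s chunk corresponds to 1500 encoder positions.

def encoder_positions(chunk_length_s: int) -> int:
    frames_per_second = 100  # Whisper log-mel frame rate
    conv_downsample = 2      # the second encoder conv has stride 2
    return chunk_length_s * frames_per_second // conv_downsample

print(encoder_positions(30))  # 1500, the size of the pretrained positional embedding table
print(encoder_positions(15))  # 750, e.g. a shorter chunk for distil-whisper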
Enabling `remove_input_padding` in the encoder wasn't as easy as I thought: all of my trials failed at the step where the positional embeddings are added to the conv output. In that setup the chunk size is not defined at build time, and this didn't work because the first dim of the positional embeddings tensor is 1500, which corresponds to 30s inputs. When the chunk_size is known at build time it's easy to slice the positional embeddings tensor to the correct size and add it to the conv output, but when the chunk size is unknown, the build fails at fetching the correct indices, for example:
import tensorrt_llm.functional as F

positional_embeddings = F.gather(
    positional_embedding,
    dim=0,
    indices=F.concat(
        [F.arange(0, input_length, "int32") for input_length in input_lengths.unbind()]
    ),
)
## only for padded input
positional_embeddings = F.view(positional_embeddings, [-1, chunk_size, hidden_size])
##
x = x + positional_embeddings
`input_lengths.unbind()` fails because the shape of `input_lengths` is `[-1]`, i.e. the batch dimension is dynamic, so the number of sequences to unbind isn't known at build time.
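For contrast, a minimal sketch of the path that does work when the chunk size is fixed at build time (assumed names: `positional_embedding` is the full 1500-row table, `n_audio_ctx` and `hidden_size` are plain Python ints, so the slice bounds are static and no dynamic gather is needed):

import tensorrt_llm.functional as F

# chunk_size known at build time, e.g. 15 s -> n_audio_ctx = 15 * 100 // 2 = 750
positional_embeddings = F.slice(
    positional_embedding,          # [1500, hidden_size] table for 30 s inputs
    [0, 0],                        # static start indices
    [n_audio_ctx, hidden_size],    # static slice sizes
)
x = x + positional_embeddings      # x: [batch_size, n_audio_ctx, hidden_size]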
Removing input padding from the encoder isn't that important, TBH, as we expect encoder inputs to be of the same shape and size except for the last window of an audio. It would mainly be beneficial in scenarios where the requests are multiple audio files that are all shorter than 30s and vary a lot in length.
On the other hand, `remove_input_padding` is important on the decoder side because it's required to enable inflight batching. From a quick trial on a 30 min audio file, the larger the batch size, the slower the generation.
efficiency = generation loops needed / actual generation loops (the latter determined by the longest sequence in the output batch)
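As one plausible reading of this metric (illustrative helper, not code from the PR): with static batching every decode loop advances the whole batch, so a sequence that has already finished contributes no useful work, and the efficiency is roughly the mean output length divided by the longest output length.

def batching_efficiency(output_lengths):
    # decode loops that produce useful tokens vs. loops actually executed for the batch
    useful_steps = sum(output_lengths)
    executed_steps = len(output_lengths) * max(output_lengths)
    return useful_steps / executed_steps

print(batching_efficiency([100]))      # 1.0  -> batch_size=1 never wastes a loop
print(batching_efficiency([100, 74]))  # 0.87 -> the shorter sequence idles for 26 loops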
Time is for decoding only, over the whole 30 mins (mean ± std. dev. of 7 runs, 1 loop each):

| batch_size | efficiency | decode time |
|---|---|---|
| 1 | 1.00 | 626 ms ± 17.5 ms |
| 2 | 0.87 | 716 ms ± 50.9 ms |
| 4 | 0.80 | 755 ms ± 31.1 ms |
| 8 | 0.67 | 1.05 s ± 21 ms |
| 16 | 0.56 | 1.33 s ± 38 ms |
| 32 | 0.45 | 1.85 s ± 34.4 ms |
As we can see, the time taken increases with batch size, which is counterproductive for large workloads, hence the need for inflight batching.
@MahmoudAshraf97 Hi, thanks for your effort. I will take this PR into our internal GitLab. Also, we will add your name to the co-author list and credit your work in the release notes for the Whisper IFB feature.
@MahmoudAshraf97 Hi, I just tried the more-than-1-dynamic-shape conv1d solution with the code below:
x = Tensor(name="x",
           dtype=self._dtype,
           shape=[-1, self.config.n_mels, -1],
           dim_range=OrderedDict([
               ("batch_size", [bs_range]),
               ("feature_dim", [self.config.n_mels]),
               ("feature_len_range", [1, 1000, 3000]),
           ]))
However, the build process failed. It seems the slice operator needs to know the value of x.shape[1].
I was wondering why you set a fixed config.chunk_length here rather than letting it be dynamic.
As I mentioned in my trials in the PR, this was a step towards making it work, but I couldn't complete it because of the slice operator, or other operators that add the positional embeddings to `x`. Before this change the build failed at the first conv layer; now it passes the conv layers and fails at a later stage, so I guess we are halfway there.
@MahmoudAshraf97 I see, thanks. BTW, the `remove_input_padding` issue for the decoder has been fixed. The code will be synced to GitHub about a week later.
Closing this since it was merged.