text-generation-inference
Improve inference speed of Santacoder and Starcoder (and others)
I did some extensive investigation, testing and benchmarking, and determined that the following is needed to speed up inference for the Bigcode models (and most text-gen-inference models):

1. Use `FlashAttention` for prefill only. This is recommended by the authors because the `FlashAttention` kernel relies on a high query length to achieve good parallelization, and because `FlashAttention` needs a lot of extra work on the inputs/outputs/KV caches for each token.
2. Vectorize as many pre/post-processing operations as possible, i.e. avoid loops (especially for cuda ops). The warpers / logit processors have already been vectorized in #317, and the rest of `causal_lm` has a prototype implementation in #272 (`flash_causal_lm` is harder to vectorize, but according to the point above `causal_lm` should be preferable).
3. Perform some form of KV cache pre-allocation and key length padding to a multiple of 8 (see the first sketch after this list). A complete, static pre-allocated tensor adds complications because of the need to concatenate/filter batches, but it's easy to pre-allocate only a few tokens in advance so that the slow concatenation runs every N tokens instead of on every one. (Again, this is not doable with `FlashAttention`.) Padding the key length to a multiple of 8 also provides a high speedup, so N=8 is a bare minimum (though higher is better).
4. Compute the `details` (logprobs, prefill data, etc.) only when requested (#288). These take a lot of time and force computing the whole model head (see point 5), but the results are almost always thrown away.
5. Compute the model head only for the last token in prefill (unless we do need it for `details`; see the second sketch below). This saves some time and, more importantly, avoids a memory bottleneck.
6. Use deterministic generation only when a seed is provided. Otherwise, sampling needs to be done in a loop because PyTorch doesn't support vectorized generators (see the third sketch below).
7. Trim the python code. Avoid any unnecessary function calls (inline when possible), attribute lookups, etc., as these end up contributing a lot to the CPU latency. Avoid subclassing `nn.Module` because it adds a lot of bloat (hooks) on `__call__` and `__getattr__`. In tests I was able to reduce the Santacoder min latency by more than 20% this way.
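To make point 3 concrete, here is a minimal sketch of growing a KV cache in chunks instead of concatenating on every token. This is not the text-generation-inference code; the tensor layout, the `append_kv` helper and the chunk size of 8 are illustrative assumptions.

```python
import torch

CHUNK = 8  # grow the cache in multiples of 8 so most decode steps skip the copy

def append_kv(cache: torch.Tensor, new_kv: torch.Tensor, length: int):
    """Append one token's key (or value) to a pre-allocated cache.

    cache:  [batch, allocated_len, head_dim], allocated_len a multiple of CHUNK
    new_kv: [batch, 1, head_dim]
    length: number of valid tokens currently in the cache
    Returns the (possibly re-allocated) cache and the new valid length.
    """
    if length == cache.size(1):
        # Slow path, hit only once every CHUNK tokens: allocate CHUNK extra slots and copy.
        bigger = cache.new_empty(cache.size(0), length + CHUNK, cache.size(2))
        bigger[:, :length] = cache
        cache = bigger
    # Fast path: write in place, no concatenation.
    cache[:, length] = new_kv.squeeze(1)
    return cache, length + 1
```

The padded slots also keep the key length at a multiple of 8, which is what gives the extra kernel speedup mentioned above; the attention mask has to hide the unused positions.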
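Point 5 (the second sketch) amounts to slicing the hidden states before the LM head during prefill. Names and shapes below are illustrative, not the actual `causal_lm`/`flash_causal_lm` code:

```python
import torch

def prefill_logits(hidden_states: torch.Tensor, lm_head: torch.nn.Linear,
                   need_details: bool = False) -> torch.Tensor:
    """hidden_states: [batch, seq_len, hidden] output of the transformer body."""
    if not need_details:
        # Only the last position is needed to sample the next token, so skip the
        # [batch, seq_len, vocab] projection, which is slow and a large memory spike.
        hidden_states = hidden_states[:, -1:, :]
    return lm_head(hidden_states)
```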
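And point 6 (the third sketch): `torch.multinomial` accepts only a single generator, so reproducible per-request seeds force a Python loop, while the unseeded path can sample the whole batch in one call. Again a hedged illustration, not the repo's sampling code:

```python
from typing import List, Optional

import torch

def sample_next_tokens(probs: torch.Tensor,
                       generators: Optional[List[torch.Generator]] = None) -> torch.Tensor:
    """probs: [batch, vocab]; generators: one seeded generator per row, or None."""
    if generators is None:
        # No seeds requested: a single vectorized call for the whole batch.
        return torch.multinomial(probs, num_samples=1)
    # Seeded requests: one call per row, since the generator cannot be vectorized.
    return torch.stack([
        torch.multinomial(p, num_samples=1, generator=g)
        for p, g in zip(probs, generators)
    ])
```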
Future work (more investigation needed):
- Try and compare more fused kernels. For fused softmax, compare JIT (used in #272) and Megatron's implementation (probably better). Compare fused and standard layer norm (results below seem to go against fused). Try fused dense (with gelu) in the MLP (or try JIT?).
- Reduce memory allocations by pre-allocating and/or reusing tensors. The main obstacle is that many operations still don't support the `out` argument, so some (easy) cpp work would be needed.
- Write the cpu-intensive part (`Block`) in cpp. This would not be too hard and would help a lot with the latency for smaller models, but may not be needed if cuda graphs are used.
- Add support for cuda graphs, at least for decode (see the sketch after this list). I already showed them to work with dynamic shapes (using a lot of graphs), and they add a big speedup for Santacoder (and a small one for Starcoder), but they add complications on batch `concatenate`/`filter` due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every `filter`; it should be OK since the (Santacoder) decode latency is mostly independent of the batch size anyway.
- Look more into tensor parallelism. I know it's already implemented in text-gen-inference, but I haven't looked into it myself.
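For the cuda graphs item, the general capture/replay pattern looks roughly like the sketch below. This is not the implementation discussed later in the thread; `decode_step`, the fixed buffers, and the one-graph-per-batch-size policy are assumptions for illustration.

```python
import torch

@torch.inference_mode()
def capture_decode_graph(decode_step, example_inputs):
    """Capture one decode step into a CUDA graph with static input/output buffers.

    The KV cache and these buffers must stay at fixed addresses, which is exactly
    why batch concatenate/filter becomes awkward with this approach.
    """
    static_inputs = [t.clone() for t in example_inputs]

    # Warm up on a side stream so workspace allocations happen before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        static_output = decode_step(*static_inputs)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = decode_step(*static_inputs)
    return graph, static_inputs, static_output

# Replay: copy the new step's tensors into static_inputs (in place), call
# graph.replay(), and read the result from static_output. Capturing one graph per
# supported batch size (e.g. powers of 2) sidesteps the dynamic-shape problem.
```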
Some benchmarking results, comparing several implementations:
- `flash`: `flash_santacoder`, the current implementation.
- `causal`: The `gpt_bigcode` model from HF transformers, run with `causal_lm`.
- `vector`: The `gpt_bigcode` model from HF transformers, run with `vectorized_causal_lm` from #272 (Opt. 2 above).
- `bigcode`: The `gpt_bigcode` model from the Bigcode transformers repo, with minor adaptations and trimming to work with text-gen-inference and `vectorized_causal_lm` (Opt. 1, 3, 4, 5, 6).
- `bigcode2`: `bigcode` with some additional optimizations taken from `flash_santacoder`, mainly the `FastLinear` and `FastLayerNorm` layers. Also some simplifications on the attention mask.
- `bigcode3`: `bigcode2` with trimmed python code (Opt. 7).
Note: flash and causal are based on commit 5a58226 (May 16th) so may be missing the latest optimizations.
Also note: the curves are smoothed out; otherwise they oscillate wildly without key length padding (`causal` and `vector`).
Santacoder decode
- For batch size 1, CPU is always the bottleneck. `flash` is the fastest, and there is a huge difference between `bigcode1/2/3`. Megatron's fused softmax might bring `bigcode3` and `flash` nearly on par (I still expect `flash` to be faster because it has fewer kernels).
- `flash` and `causal` are really bad at high batch size, especially for long sequences. This is attributable to non-vectorized operations and the poor performance of FlashAttention. `vector` already brings down the batch size overhead to a minimum. `bigcode1/2/3` show additional improvements.
- Surprisingly, `bigcode2/3` are slower than `bigcode` for bs=256 and large sequences. Attributable to sub-optimal fused layer norm?
Santacoder prefill
- `causal` and `vector` are really bad (no FlashAttention).
- `flash` is not that great either; it seems attributable to all the processing in `generate_token`.
- `bigcode1/2/3` work the best and are very similar (except for bs=1 when CPU-bound). `bigcode2/3` are marginally better in general (because of fused layer norm?).
Starcoder decode
- Similar to Santacoder, but `flash` is already inefficient at a batch size of 1, often even worse than `causal`.
- Latency for small batch sizes is bottlenecked by reading the weights (15.5e9 params * 2 B/param / 2039e9 B/s = 15.2 ms; see the sketch after this list), so tensor parallelism would likely reduce it.
- `causal` goes crazy for large sequences, not sure why.
- Again, `bigcode2/3` are worse than `bigcode`; I suspect the fused layer norm.
- For batch size 256, the times at small seqlen are higher than for smaller batch sizes, suggesting reading the weights is no longer the bottleneck.
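The 15.2 ms figure above is the memory-bandwidth floor for one decode step: every weight has to be read once per token, so weight bytes divided by memory bandwidth bounds the latency from below (assuming fp16 weights and roughly an A100 80GB's ~2039 GB/s bandwidth):

```python
params = 15.5e9        # Starcoder parameter count
bytes_per_param = 2    # fp16/bf16 weights
bandwidth = 2039e9     # GPU memory bandwidth in bytes/s

print(params * bytes_per_param / bandwidth * 1000)  # ~15.2 ms per decode step
```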
Starcoder prefill
- Similar to Santacoder.
- `bigcode2/3` are marginally faster than `bigcode` but run out of memory faster.
@jlamypoirier Thanks for the great investigation.

> Add support for cuda graphs, at least for decode. I already showed them to work with dynamic shapes (using a lot of graphs), and they add a big speedup for Santacoder (and a small one for Starcoder), but they add complications on batch concatenate / filter due to the static KV cache location. An option would be to always decode with the same batch size (or a few pre-determined values, e.g. powers of 2) to avoid costly shuffling of the data on every filter; it should be OK since the (Santacoder) decode latency is mostly independent of the batch size anyway.

Can you show me where you implemented the cuda graphs with dynamic size for SantaCoder? I wonder how it is implemented.
Sorry for the late response, you can find my (messy) implementation in https://github.com/bigcode-project/transformers/blob/main/src/transformers/models/gpt_bigcode/inference_runner.py. Note that this version supports dynamic key lengths but not dynamic batch sizes.
@jlamypoirier Amazing reports!! May I ask whether sequence length indicates `max_new_tokens`? I got pretty high latency (about 4 s) for Starcoder when I set `max_new_tokens=128`.
It's the time to generate one token. For the full time you need to add the prefill for the context length plus one generate step for each position in `range(context_length, context_length + max_new_tokens)`.
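To put the reported numbers together (illustrative arithmetic only): ignoring the prefill share gives an upper bound on the per-token decode time.

```python
max_new_tokens = 128
total_seconds = 4.0    # reported end-to-end latency (prefill + all decode steps)

# Dividing the whole latency by the token count over-counts decode, so this is an upper bound.
print(total_seconds / max_new_tokens * 1000)  # ~31 ms per generated token
```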
@jlamypoirier These are great suggestions. Have any of these found their way upstream? If not, is your version available anywhere?
edit: especially curious about

> Compute the model head only for the last token in prefill (unless we do need them for details). This saves some time and more importantly avoids a memory bottleneck.