sys_reading
Efficiently Scaling Transformer Inference
https://proceedings.mlsys.org/paper_files/paper/2023/file/523f87e9d08e6071a3bbd150e6da40fb-Paper-mlsys2023.pdf
https://zhuanlan.zhihu.com/p/660715870
summary
key problem
workload
efficient generative inference for Transformer models (whereas #256 applies generally to all DNN models)
large deep models, with tight latency targets and long sequence lengths
optimization goal
depends on the requirements of downstream applications:
- Some applications, including interactive workloads like chatbots, involve tight latency constraints.
- Others, including offline inference for scoring or distillation, emphasize high throughput and low cost per token at any latency.
configurations to tune
model parallelization: how to partition the multi-head attention and FFN layers of each Transformer block
scenario
datacenter, with TPUs
technique
xxxxx
dynamic workload?
xxxxx
multi-tenant?
xxxxx
implementation
xxxxx
Problem and motivation
what is the problem this paper is solving?
why is it important?
why is it challenging?
challenges [ch1]
- large memory footprint of both the trained model parameters and the transient state needed during decoding; both must be kept in accelerator (GPU/TPU) memory during inference.
- the model parameters cannot fit on a single chip
- The attention key and value tensors of each layer, which we refer to as the KV cache, must also be stored in memory for the duration of decoding. This state can be large for long context windows.
- [ch2 compute costs] unlike the weights (model parameters), the KV cache is unique for each sequence in the batch.
- [ch2 memory costs] at every decoding step, the weights and the KV cache must be loaded from HBM (accelerator DRAM) into the compute cores, which takes time. At small batch sizes and sequence lengths, the time to load weights dominates. At larger batch sizes and sequence lengths (e.g. 2048+ tokens with batch size 512+), the time to load the KV cache dominates.
- lower parallelizability of Transformer generation relative to training, since generating later tokens depends on earlier outputs (see the figure from #352). For example, if the input prompt is "ABCD" and the LLM outputs "EFG", then "F" can only be generated after "E" is known, and "G" only after "F".
- inference cost from the attention mechanism scales quadratically with input sequence length
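To make the weight-load vs. KV-cache-load tradeoff concrete, here is a back-of-envelope sketch of the bytes that must be read from HBM per decode step. It assumes hypothetical GPT-3-sized dimensions (96 layers, $d_{model}$ = 12288, 175B parameters), bf16 storage (2 bytes per value), and standard multi-head attention where each layer caches one K and one V vector of size $d_{model}$ per token. None of these numbers come from the paper.

```python
BYTES_PER_VALUE = 2  # assumption: bf16 weights and KV cache


def weight_bytes(n_params: float) -> float:
    """Bytes needed to load every weight once (required at every decode step)."""
    return n_params * BYTES_PER_VALUE


def kv_cache_bytes(n_layers: int, d_model: int, batch: int, seq_len: int) -> float:
    """Bytes of cached K and V for the whole batch (multi-head attention assumed)."""
    per_token = 2 * n_layers * d_model * BYTES_PER_VALUE  # 2 = one K and one V per layer
    return per_token * batch * seq_len


# Hypothetical GPT-3-sized model (not a configuration from the paper).
N_PARAMS, N_LAYERS, D_MODEL = 175e9, 96, 12288

for batch, seq_len in [(1, 128), (512, 2048)]:
    w = weight_bytes(N_PARAMS)
    kv = kv_cache_bytes(N_LAYERS, D_MODEL, batch, seq_len)
    verdict = "weights dominate" if w > kv else "KV cache dominates"
    print(f"B={batch:4d} L={seq_len:5d}: weights {w / 1e9:7.1f} GB, "
          f"KV cache {kv / 1e9:8.1f} GB -> {verdict}")
```

Because the weights are shared by every sequence while the KV cache grows linearly with $B \cdot L$, the crossover point moves to smaller batches as sequences get longer.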
metrics of an inference job [ch2]
- latency: total time for an inference, which can be broken down into
- the time to process the input tokens present at the start of the inference (which we call “prefill”)
- and the time to autoregressively generate output tokens (which we call “decode”). The decode latency can also be measured “per step”, i.e. divided by the number of tokens in each sequence.
- The throughput of prefill or decode is the number of tokens processed or generated per second
- model FLOPS utilization (MFU) is the ratio of the observed throughput to the theoretical maximum throughput if the benchmarked hardware setup were operating at peak FLOPS with no memory or communication overhead. (similar to CPU utilization)
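The latency decomposition and the MFU definition above can be written as two small formulas. This sketch approximates the FLOPs per token as $2 \cdot n_{params}$ (one multiply and one add per parameter) and ignores attention FLOPs. The timings, chip count, and per-chip peak FLOPS in the usage lines are hypothetical placeholders, not measurements from the paper.

```python
def total_latency(prefill_s: float, decode_step_s: float, l_gen: int) -> float:
    """Total latency = prefill time + L_gen sequential decode steps."""
    return prefill_s + decode_step_s * l_gen


def mfu(tokens_per_s: float, n_params: float, n_chips: int,
        peak_flops_per_chip: float) -> float:
    """Observed FLOPS divided by the hardware's theoretical peak FLOPS.

    Assumes ~2 * n_params matmul FLOPs per token and ignores attention FLOPs.
    """
    observed_flops = tokens_per_s * 2 * n_params
    return observed_flops / (n_chips * peak_flops_per_chip)


# Hypothetical numbers for illustration only.
print(total_latency(prefill_s=0.3, decode_step_s=0.03, l_gen=64))
print(mfu(tokens_per_s=20_000, n_params=62e9, n_chips=64, peak_flops_per_chip=275e12))
```

Note that the decode term scales with $L_{gen}$ no matter how many chips are used, which is why per-step decode latency is reported separately.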
tradeoff space between latency/throughput/cost [ch2.1, Fig. 1]
problem formulation [ch2.2, 3.1]
model
- a Transformer model with
- $n_{params}$ parameters
- embed dimension is $d_{model}$ (or $E$)
- feedforward intermediate dimension $d_{ff}$ (or $F$); typically $F = 4E$
- $n_{heads}$ (or $H$) heads.
- $L$: sequence length
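As a sanity check on the notation, the weight count of one standard Transformer block can be written in terms of $E$ and $F$: the Q, K, V, and output projections contribute $4E^2$ and the two FFN matrices contribute $2EF$, giving $12E^2$ per layer when $F = 4E$. This sketch assumes dense multi-head attention with total head width equal to $E$ and a two-matrix FFN, and it ignores biases, layer norms, and embeddings. The layer count `n_layers` and the example sizes are hypothetical, since the notes above do not define them.

```python
def params_per_layer(E: int, F: int) -> int:
    attention = 4 * E * E  # W_Q, W_K, W_V, W_O, each E x E
    ffn = 2 * E * F        # W_in (E x F) and W_out (F x E)
    return attention + ffn


def approx_n_params(n_layers: int, E: int, F: int) -> int:
    """Rough n_params, ignoring embeddings, biases, and norms."""
    return n_layers * params_per_layer(E, F)


E = 12288
print(approx_n_params(n_layers=96, E=E, F=4 * E) / 1e9)  # roughly 174 (billion)
```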
device layout [ch3.1]
- $n_{chips}$ chips (GPU/TPU/...) connected in an $X \times Y \times Z$ 3D torus topology
- the chips form a 3D grid with wraparound links at the edges, so each chip is connected to its 6 neighbors (±x, ±y, ±z).
- comment: the device mesh in #256 is a 2D mesh, which is probably more common in datacenters
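The wraparound connectivity of the 3D torus can be sketched as a neighbor function over chip coordinates $(x, y, z)$: each chip links to $\pm 1$ along each axis, modulo that axis's length. The mesh sizes are hypothetical. When an axis has length 2, the $+1$ and $-1$ neighbors along it are the same chip, so fewer than 6 distinct neighbors remain.

```python
from typing import List, Tuple

Coord = Tuple[int, int, int]


def torus_neighbors(coord: Coord, dims: Coord) -> List[Coord]:
    """Return the (+/-x, +/-y, +/-z) neighbors of a chip in an X*Y*Z torus."""
    neighbors = []
    for axis in range(3):
        for step in (-1, +1):
            n = list(coord)
            n[axis] = (n[axis] + step) % dims[axis]  # wraparound link at the edges
            neighbors.append(tuple(n))
    return neighbors


# A hypothetical 4 x 4 x 4 slice: even the corner chip has 6 distinct neighbors.
print(torus_neighbors((0, 0, 0), (4, 4, 4)))
# [(3, 0, 0), (1, 0, 0), (0, 3, 0), (0, 1, 0), (0, 0, 3), (0, 0, 1)]
```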
tensor partition layouts [ch3.1]
- $BLE_{xyz}$ means the last dimension $E$ of a tensor of logical shape $BLE$ is split into $XYZ$ partitions. Each chip gets one partition, which is a tensor of shape $[B, L, E/(XYZ)]$
- If a tensor is replicated over an axis $x$, that axis is omitted from the notation ($BLE_{yz}$).
- partialsum-$x$ means a given tensor has been contracted (summed) locally on each chip (over axis $x$ not represented in the shape), but still needs to be summed across the chips in the TPU $x$ axis (creating a tensor replicated over $x$) before the result is meaningful.
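The layout notation can be simulated with NumPy on a single host. This is a minimal sketch, assuming a hypothetical $2 \times 2 \times 2$ mesh and row-major ordering of shards over the listed mesh axes (the notation itself does not specify which chip gets which slice). `shard_last_dim` is a hypothetical helper: passing `"xyz"` produces $BLE_{xyz}$, and passing `"yz"` produces $BLE_{yz}$, where chips that differ only in their $x$ coordinate hold identical replicas.

```python
import itertools

import numpy as np


def shard_last_dim(t: np.ndarray, mesh: dict, axes: str) -> dict:
    """Map each chip coordinate to its shard of t, splitting the last dim over `axes`.

    Mesh axes not listed in `axes` are replicated (omitted from the notation).
    """
    names = list(mesh)  # e.g. ["x", "y", "z"]
    n_shards = int(np.prod([mesh[a] for a in axes]))
    chunks = np.split(t, n_shards, axis=-1)  # raises if E is not divisible
    out = {}
    for coord in itertools.product(*(range(mesh[n]) for n in names)):
        pos = dict(zip(names, coord))
        idx = 0
        for a in axes:  # row-major index over the sharded axes only
            idx = idx * mesh[a] + pos[a]
        out[coord] = chunks[idx]
    return out


B, L, E = 2, 3, 16
t = np.arange(B * L * E).reshape(B, L, E)
mesh = {"x": 2, "y": 2, "z": 2}

ble_xyz = shard_last_dim(t, mesh, "xyz")
print(ble_xyz[(0, 0, 0)].shape)  # (2, 3, 2) == [B, L, E / (X*Y*Z)]

ble_yz = shard_last_dim(t, mesh, "yz")
print(ble_yz[(0, 1, 1)].shape)   # (2, 3, 4); chip (1, 1, 1) holds the same data
```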
communication collectives [ch3.1, Figure A.1 ]
- all-reduce(x): sums the partialsum-$x$ tensors across the chips along the $x$ axis, then broadcasts the sum back to all participating chips
- reduce-scatter(x): sums the partialsum-$x$ tensors across the chips along the $x$ axis, but the result is sharded across those chips: each chip ends up with one slice of the sum, and none of them holds the complete result.
- all-gather(x): each chip broadcasts its shard to all chips along the $x$ axis and the shards are concatenated, e.g. $BLE_{xyz} \to BLE_{yz}$ (the result is replicated over $x$)
- all-to-all: shifts the sharding from one tensor dimension to another (similar to a keyBy repartition in Flink)
- also see
- https://zhuanlan.zhihu.com/p/653968730
- https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html
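The collectives can be simulated on one host by treating a Python list as the chips along the $x$ axis (list index = chip position). This is a single-process sketch for intuition, not how NCCL or the TPU runtime implement them, and the sizes are hypothetical. A partialsum-$x$ tensor arises naturally when the contracted dimension $E$ of a matmul is sharded over $x$: each chip multiplies its own slice and so holds only part of the sum.

```python
import numpy as np

rng = np.random.default_rng(0)
B, E, F, X = 4, 8, 8, 4  # hypothetical sizes; X = number of chips along x
act = rng.standard_normal((B, E))
w = rng.standard_normal((E, F))

# Shard the contracted dim E over x: chip i holds act[:, E_i] and w[E_i, :].
partials = [a @ b for a, b in zip(np.split(act, X, axis=1), np.split(w, X, axis=0))]
# Each entry of `partials` is partialsum-x with shape (B, F).


def all_reduce(ps):
    total = sum(ps)
    return [total.copy() for _ in ps]  # every chip ends with the full sum


def reduce_scatter(ps, axis=-1):
    return np.split(sum(ps), len(ps), axis=axis)  # chip i keeps only slice i


def all_gather(shards, axis=-1):
    full = np.concatenate(shards, axis=axis)
    return [full.copy() for _ in shards]  # result is replicated over x


def all_to_all(shards, split_axis, concat_axis):
    """Move the sharding from `concat_axis` (currently sharded) to `split_axis`."""
    n = len(shards)
    pieces = [np.split(s, n, axis=split_axis) for s in shards]  # pieces[src][dst]
    return [np.concatenate([pieces[src][dst] for src in range(n)], axis=concat_axis)
            for dst in range(n)]


expected = act @ w
print(np.allclose(all_reduce(partials)[0], expected))
# all-reduce is equivalent to reduce-scatter followed by all-gather
print(np.allclose(all_gather(reduce_scatter(partials))[3], expected))
```

The last check reflects a common design choice: a large all-reduce is often decomposed into a reduce-scatter plus an all-gather, which lets a layer keep its output sharded between the two halves.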
inference stages
An inference request is executed as a batch of $B$ sequences. Each sequence has $L_{input}$ tokens of input text (all present at the start of the inference) and generates $L_{gen}$ tokens of output text.
- prefill: run the model over all $B \cdot L_{input}$ input tokens in parallel, in a single forward pass
- generation/decode: the output tokens are generated autoregressively, in a sequential loop of $L_{gen}$ steps. Each step consists of a single forward pass through the model, after which one new token is sampled for each of the $B$ sequences in the batch.
- prefill can run in parallel over $L_{input}$, but decode must run sequentially over $L_{gen}$
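- the former is compute-bound and the latter is memory-bandwidth-bound. The decode stage in particular is notoriously hard to optimize: each step decodes only one token per sequence, so there is little computation, yet it must read all the weights plus a KV cache that keeps growing with every decoding step, so it is severely limited by the accelerator's memory bandwidth.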
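The prefill/decode split can be illustrated with a toy single-head causal attention layer in NumPy. Prefill computes K and V for all $L_{input}$ prompt positions in one batched pass and stores them as the KV cache. Each decode step then projects only the one new position, appends its K and V to the cache, and attends over the whole cache. This is a minimal sketch under simplifying assumptions: one layer, one head, random weights, $B = 1$, and each new input is a random embedding standing in for a sampled token. It checks that cached decoding matches recomputing full causal attention from scratch.

```python
import numpy as np

rng = np.random.default_rng(0)
E = 16  # hypothetical d_model (also the head dim here)
Wq, Wk, Wv = (rng.standard_normal((E, E)) / np.sqrt(E) for _ in range(3))


def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)
    p = np.exp(s)
    return p / p.sum(axis=-1, keepdims=True)


def full_causal_attention(x):
    """Reference: recompute attention over all positions with a causal mask."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(E)
    mask = np.triu(np.ones_like(scores, dtype=bool), k=1)  # hide future positions
    return softmax(np.where(mask, -np.inf, scores)) @ v


def prefill(x_prompt):
    """One parallel pass over all prompt tokens; returns outputs and the KV cache."""
    k_cache, v_cache = x_prompt @ Wk, x_prompt @ Wv
    return full_causal_attention(x_prompt), k_cache, v_cache


def decode_step(x_new, k_cache, v_cache):
    """Process one new token: append its K/V, then attend over the whole cache."""
    k_cache = np.vstack([k_cache, x_new @ Wk])
    v_cache = np.vstack([v_cache, x_new @ Wv])
    q = x_new @ Wq  # shape (1, E)
    out = softmax(q @ k_cache.T / np.sqrt(E)) @ v_cache
    return out, k_cache, v_cache


L_input, L_gen = 5, 3
x_prompt = rng.standard_normal((L_input, E))
out, k_cache, v_cache = prefill(x_prompt)
xs, outs = [x_prompt], [out]
for _ in range(L_gen):  # sequential: one forward pass per generated token
    x_new = rng.standard_normal((1, E))  # stand-in for the sampled token's embedding
    step_out, k_cache, v_cache = decode_step(x_new, k_cache, v_cache)
    xs.append(x_new)
    outs.append(step_out)

x_all = np.vstack(xs)
print(k_cache.shape)  # (8, 16): the cache grows by one row per decode step
print(np.allclose(np.vstack(outs), full_causal_attention(x_all)))
```

Prefill is a matrix-matrix computation over $L_{input}$ rows, while each decode step is a vector-matrix computation over one row that still has to read the entire, growing cache. This difference is the compute-bound vs. memory-bound contrast described above.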
Main ideas and insights
describe the paper gist in 1-2 sentences
what is important to remember? What did we learn?
provides a set of engineering principles for how best to partition a model to scale Transformer inference
Solution description
explain how the solution work
Important results
describe the experimental setup
summarize the main results
Limitations and opportunities for improvement
when doesn't it work?
what assumptions does the paper make and when are they valid?
Closely related work
list of main competitors and how they differ
Follow-up research ideas (Optional)
If you were to base your next research project on this paper, what would you do?
Propose concrete ways to achieve one or more of the following:
Build a better (faster, more efficient, more user-friendly...) system to solve the same problem
Solve a generalization of the problem
Address one of the work's limitations
Solve the same problem in a different context
Solve the problem in a much larger scale
Apply the paper's methods to a different (but similar) problem
Solve a new problem created by this work