
Efficiently Scaling Transformer Inference

pentium3 opened this issue 1 year ago • 2 comments

https://proceedings.mlsys.org/paper_files/paper/2023/file/523f87e9d08e6071a3bbd150e6da40fb-Paper-mlsys2023.pdf

pentium3 commented Feb 29 '24 06:02

https://zhuanlan.zhihu.com/p/660715870

pentium3 commented Mar 13 '24 05:03

summary

key problem

workload

Efficient generative inference for Transformer models (whereas #256 applies to DNN models in general).

large deep models, with tight latency targets and long sequence lengths

optimization goal

Depends on the requirements of downstream applications:

  • Some applications, including interactive workloads like chatbots, involve tight latency constraints.
  • Others, including offline inference for scoring or distillation, emphasize high throughput and low cost per token at any latency.

configurations to tune

Model parallelization: how to partition the multi-head attention and feed-forward (FFN) layers of each Transformer block across chips.

scenario

Datacenter, with TPUs.

technique

xxxxx

dynamic workload?

xxxxx

multi-tenant?

xxxxx

implementation

xxxxx

Problem and motivation

what is the problem this paper is solving?
why is it important?
why is it challenging?

challenges [ch1]

  • large memory footprint of both the trained model parameters and the transient state needed during decoding; both must be kept in accelerator (GPU/TPU) memory during inference.
    • The model parameters cannot fit on a single chip.
    • The attention key and value tensors of each layer, which the paper calls the KV cache, must also be stored in memory for the duration of decoding; this state can be large for long context windows.
    • [ch2 compute costs] Unlike the weights (model parameters), the KV cache is unique for each sequence in the batch.
    • [ch2 memory costs] Loading weights and the KV cache from memory takes time. At small batch sizes and sequence lengths, the time to load weights dominates. At larger batch sizes and sequence lengths (e.g. 2048+ tokens with batch size 512+), the time to load the KV cache dominates. (A back-of-the-envelope size comparison is sketched after this list.)
  • lower parallelizability of Transformer generation relative to training, since each generated token depends on the previously generated tokens (see the figure in #352). For example, if the input prompt is "ABCD" and the LLM outputs "EFG", "F" cannot be generated until "E" has been sampled.
  • inference cost from the attention mechanism scales quadratically with input sequence length
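
To make the memory-footprint challenge concrete, here is a back-of-the-envelope sketch (my own illustrative model shapes and dtype sizes, not numbers from the paper) comparing weight size with KV-cache size at a large batch and sequence length:

```python
# Back-of-the-envelope: weights vs. KV cache.
# All model shapes and dtype sizes below are illustrative assumptions.

def weight_bytes(n_params: int, bytes_per_param: int = 2) -> int:
    """Total bytes for the model parameters (e.g. bf16 -> 2 bytes per parameter)."""
    return n_params * bytes_per_param

def kv_cache_bytes(batch: int, seq_len: int, n_layers: int,
                   n_heads: int, d_head: int, bytes_per_el: int = 2) -> int:
    """KV cache: one K and one V tensor per layer, unique for every sequence in the batch."""
    return 2 * n_layers * batch * seq_len * n_heads * d_head * bytes_per_el

w = weight_bytes(540 * 10**9)                      # hypothetical ~540B-parameter model
kv = kv_cache_bytes(batch=512, seq_len=2048,
                    n_layers=118, n_heads=48, d_head=256)
print(f"weights : {w / 1e12:.1f} TB")              # ~1.1 TB
print(f"KV cache: {kv / 1e12:.1f} TB")             # ~6.1 TB: dominates at this batch/seq
```

At this batch size and sequence length the KV cache is several times larger than the weights, which is why loading it dominates the per-step time in that regime.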

metrics of inference job [ch2]

  • latency: total time for an inference and can be broken down into
    • the time to process the input tokens present at the start of the inference (which we call “prefill”)
    • and the time to autoregressively generate output tokens (which we call “decode”). The decode latency can also be measured “per step”, i.e. divided by the number of tokens in each sequence.
  • The throughput of prefill or decode is the number of tokens processed or generated per second
  • model FLOPS utilization (MFU) is the ratio of the observed throughput to the theoretical maximum throughput if the benchmarked hardware setup were operating at peak FLOPS with no memory or communication overhead (loosely analogous to CPU utilization). A toy MFU calculation is sketched after this list.
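
A toy MFU calculation, assuming the common approximation that one decoder forward pass costs about $2 \cdot n_{params}$ FLOPs per token (attention FLOPs ignored); all hardware and throughput numbers are hypothetical:

```python
def mfu(observed_tokens_per_s: float, n_params: float, peak_flops_per_s: float) -> float:
    """Model FLOPS utilization: observed throughput / theoretical maximum throughput.

    Approximation: a forward pass costs ~2 * n_params FLOPs per token, so the
    theoretical maximum throughput is peak_flops_per_s / (2 * n_params) tokens/s.
    """
    max_tokens_per_s = peak_flops_per_s / (2 * n_params)
    return observed_tokens_per_s / max_tokens_per_s

# Hypothetical setup: 64 chips at 275 TFLOP/s each, a 540B-parameter model, 10k tokens/s observed.
print(f"MFU = {mfu(10_000, 540e9, 64 * 275e12):.0%}")   # ~61%
```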

tradeoff space between latency/throughput/cost

ch2.1, Fig1

problem formulation [ch2.2, 3.1]

model

  • a Transformer model with
    • $n_{params}$ parameters
    • embed dimension is $d_{model}$ (or $E$)
    • feedforward intermediate dimension $d_{ff}$ (or $F$); typically $F = 4E$. (A rough parameter-count sanity check using these symbols follows this list.)
    • $n_{heads}$ (or $H$) heads.
  • $L$: sequence length
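
As a sanity check on these symbols, a rough parameter-count approximation (a standard rule of thumb, not a formula from the paper): each layer has about $4 d_{model}^2$ attention parameters and $2 d_{model} d_{ff}$ FFN parameters, so with $F = 4E$ the total is roughly $12 \cdot n_{layers} \cdot d_{model}^2$.

```python
def approx_params(n_layers: int, d_model: int, d_ff: int) -> int:
    """Rough decoder-only parameter count (ignores embeddings, biases, layer norms):
    attention ~ 4*E^2 per layer (Q, K, V, output projections),
    FFN       ~ 2*E*F per layer (up- and down-projection)."""
    return n_layers * (4 * d_model**2 + 2 * d_model * d_ff)

# Hypothetical shapes: E = 8192, F = 4E, 64 layers -> ~52B parameters.
print(f"{approx_params(64, 8192, 4 * 8192) / 1e9:.0f}B")
```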

device layout [ch3.1]

  • $n_{chips}$ chips (GPU/TPU/...) connected with 3D torus topology $X * Y * Z$
    • the chips form a 3D grid with wrap-around links; each chip is connected to its 6 adjacent chips (up, down, left, right, front, back).
    • comment: the device mesh topology in #256 is a 2D mesh, which should be more common in datacenters. (A small neighbour-lookup sketch follows this list.)
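
A minimal sketch (my own illustration, not from the paper) of the torus connectivity: each chip at coordinate $(x, y, z)$ is linked to its $\pm 1$ neighbours along each axis, with wrap-around at the edges:

```python
def torus_neighbors(coord, dims):
    """Neighbours of a chip in an X*Y*Z torus: +/-1 along each axis, with wrap-around."""
    x, y, z = coord
    X, Y, Z = dims
    return [
        ((x + 1) % X, y, z), ((x - 1) % X, y, z),
        (x, (y + 1) % Y, z), (x, (y - 1) % Y, z),
        (x, y, (z + 1) % Z), (x, y, (z - 1) % Z),
    ]

print(torus_neighbors((0, 0, 0), (4, 4, 8)))   # every chip has exactly 6 neighbours
```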

tensor partition layouts [ch3.1]

  • $BLE_{xyz}$ means the last dimension $E$ of a tensor of logical shape $BLE$ is split into $XYZ$ partitions. Each chip gets one partition, which is a tensor of shape $[B, L, E/(XYZ)]$. (A small numpy illustration of this notation follows this list.)
  • If a tensor is replicated over an axis $x$, that axis is omitted from the notation ($BLE_{yz}$).
  • partialsum-$x$ means a given tensor has been contracted (summed) locally on each chip (over some axis not represented in the shape), but still needs to be summed across the chips in the TPU $x$ axis (creating a tensor replicated over $x$) before the result is meaningful.
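
A minimal numpy illustration (my own, not the paper's code) of what the notation means in terms of per-chip shapes:

```python
import numpy as np

B, L, E = 4, 16, 512
X, Y, Z = 2, 2, 4                 # hypothetical 16-chip mesh

full = np.zeros((B, L, E))        # logical activation tensor of shape BLE

# BLE_xyz: the last (E) dimension is split into X*Y*Z partitions, one per chip.
shards_xyz = np.split(full, X * Y * Z, axis=-1)
print(shards_xyz[0].shape)        # (4, 16, 32) == (B, L, E // (X*Y*Z))

# BLE_yz: replicated over the x axis, so E is split into only Y*Z partitions,
# and every chip that shares a (y, z) coordinate holds an identical copy.
shards_yz = np.split(full, Y * Z, axis=-1)
print(shards_yz[0].shape)         # (4, 16, 64) == (B, L, E // (Y*Z))
```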

communication collectives [ch3.1, Figure A.1 ]

  • all-reduce(x): sums the partialsum-$x$ tensors across the chips in the $x$ axis, then broadcasts the sum back to all the involved chips.
  • reduce-scatter(x): sums the partialsum-$x$ tensors across the chips in the $x$ axis, but leaves the result sharded over those chips, so no single chip holds the complete result.
  • all-gather(x): broadcasts and concatenates the shards of a tensor such as $BLE_{xyz}$ across the chips in the $x$ axis, so each of them ends up with the full tensor along that axis.
  • all-to-all: shifts sharding from one tensor dimension to another; similar to key-by in Flink. (A small numpy emulation of the reduction collectives follows this list.)
  • also see
    • https://zhuanlan.zhihu.com/p/653968730
    • https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html
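
A minimal numpy emulation (my own illustration, not NCCL/XLA code) of the three reduction collectives operating on a list of per-chip partialsum-$x$ arrays; it also shows the identity all-reduce = reduce-scatter followed by all-gather:

```python
import numpy as np

# One partialsum-x array per chip on the x axis (4 chips, hypothetical data).
partials = [np.arange(8.0) * (i + 1) for i in range(4)]

# all-reduce(x): every chip ends up with the full sum.
total = sum(partials)
all_reduced = [total.copy() for _ in partials]

# reduce-scatter(x): the summed result is left sharded over the x axis,
# i.e. chip i keeps only the i-th shard of the sum.
shards = np.split(total, len(partials))
reduce_scattered = [shards[i] for i in range(len(partials))]

# all-gather(x): each chip concatenates the shards from all chips in the x axis,
# hence all-reduce == reduce-scatter followed by all-gather.
all_gathered = [np.concatenate(reduce_scattered) for _ in partials]
assert all(np.array_equal(a, total) for a in all_gathered)
```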

inference stages

An inference request is executed over a batch of $B$ sequences. Each sequence has $L_{input}$ tokens of input text (all present at the start of the inference) and generates $L_{gen}$ tokens of output text.

  • prefill: run the model over all $B \cdot L_{input}$ input tokens in parallel, in a single forward pass over all the tokens.
  • generation/decode: the output tokens are generated autoregressively, with a sequential loop of $L_{gen}$ steps. Each step consists of a single forward pass through the model, after which we sample one new token for each of the $B$ sequences in the batch.
  • prefill can run in parallel over $L_{input}$, but decode must run sequentially over $L_{gen}$.
  • The former is compute-bound, the latter is memory-bandwidth-bound. Decode in particular is notoriously hard to optimize: only one token is decoded per step, so the compute per step is small, yet each step must read a KV cache that keeps growing with the number of decode steps, so decode is severely limited by the accelerator's memory bandwidth. (A minimal prefill/decode sketch follows this list.)
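
A minimal greedy-decoding sketch (my own pseudo-implementation; `model_step` and its signature are hypothetical, not the paper's API) showing why prefill is a single parallel pass while decode is a sequential loop that reads a growing KV cache:

```python
import numpy as np

def generate(model_step, prompt_tokens, L_gen):
    """model_step(tokens, kv_cache) -> (logits_for_last_position, updated_kv_cache)
    is assumed to run one forward pass of the model over the given tokens."""
    # Prefill: one forward pass over all L_input prompt tokens in parallel;
    # this populates the KV cache and is compute-bound.
    logits, kv_cache = model_step(prompt_tokens, kv_cache=None)
    output = [int(np.argmax(logits))]                 # greedy sampling

    # Decode: L_gen sequential steps. Each step feeds in only the previous token,
    # but must read the entire (and growing) KV cache -> memory-bandwidth-bound.
    for _ in range(L_gen - 1):
        logits, kv_cache = model_step(output[-1:], kv_cache=kv_cache)
        output.append(int(np.argmax(logits)))
    return output
```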

Main ideas and insights

describe the paper gist in 1-2 sentences
what is important to remember? What did we learn?

The paper provides a set of engineering principles for how best to partition a model in order to scale Transformer inference.

Solution description

explain how the solution work

Important results

describe the experimental setup
summarize the main results

Limitations and opportunities for improvement

when doesn't it work?
what assumptions does the paper make and when are they valid?

Closely related work

list of main competitors and how they differ

Follow-up research ideas (Optional)

If you were to base your next research project on this paper, what would you do?
Propose concrete ways to achieve one or more of the following:

Build a better (faster, more efficient, more user-friendly...) system to solve the same problem
Solve a generalization of the problem
Address one of the work's limitations
Solve the same problem in a different context
Solve the problem in a much larger scale
Apply the paper's methods to a different (but similar) problem
Solve a new problem created by this work

pentium3 commented Mar 15 '24 20:03