
Llama: Merge query/key/value projection layers

Open · mryab opened this pull request 10 months ago · 0 comments

This PR improves inference throughput by ~7% (measured on a single A100-80GB) by merging the query/key/value projections into a single large matrix multiplication. This reduces the overhead of launching several matmul kernels, which turns out to be substantial for single-sequence, single-token inference steps. The PR also adds a --throughput dry_run option to estimate throughput without starting a server.
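For illustration, here is a minimal sketch of the idea (not the exact Petals implementation): the three projection weight matrices are concatenated along the output dimension, so a single kernel launch produces Q, K, and V, which are then split apart. The sizes below are illustrative; with grouped-query attention the K/V output dimensions would be smaller than the query one.

```python
import torch
import torch.nn as nn

hidden_size = 1024  # small for the demo; Llama-2-70B uses 8192

# Separate projections, as in the unmerged attention module (Llama uses no biases)
q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# Merged projection: stack the three weight matrices along the output dimension,
# so one matmul kernel launch replaces three
qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
    )

x = torch.randn(1, 1, hidden_size)  # a single-sequence, single-token step
q, k, v = qkv_proj(x).split(hidden_size, dim=-1)

# Sanity check: the merged projection reproduces the separate ones
assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(k, k_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)
```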

Sample results from running experiments with and without the optimization (the command in each case is CUDA_VISIBLE_DEVICES=0 python -m petals.cli.run_server petals-team/StableBeluga2 --throughput dry_run):

Current code (branch https://github.com/bigscience-workshop/petals/tree/no_qkv_merge):

Sep 03 00:50:35.135 [INFO] Inference throughput: 532.7 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:50:47.722 [INFO] Forward pass throughput: 51749.0 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:52:07.524 [INFO] Inference throughput: 576.4 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:52:20.919 [INFO] Forward pass throughput: 36552.9 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:53:54.616 [INFO] Inference throughput: 512.7 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:54:14.464 [INFO] Forward pass throughput: 50242.5 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Code from this PR:

Sep 03 00:55:25.680 [INFO] Inference throughput: 564.7 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:55:38.648 [INFO] Forward pass throughput: 33023.0 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:56:45.526 [INFO] Inference throughput: 578.4 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:56:59.632 [INFO] Forward pass throughput: 54655.0 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:58:18.783 [INFO] Inference throughput: 593.1 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:58:33.015 [INFO] Forward pass throughput: 36200.4 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
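Averaging the three runs, inference throughput goes from (532.7 + 576.4 + 512.7) / 3 ≈ 540.6 to (564.7 + 578.4 + 593.1) / 3 ≈ 578.7 tokens/sec per block, i.e. the ~7% improvement mentioned above, while forward pass throughput varies noticeably between runs but shows no consistent regression.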

mryab · Sep 02 '23