Condition to achieve linear speedup?
I tested the latency of the QuantLinear forward pass with various input sizes and feature sizes. For token counts from 1 to 1024, I see no speedup over the AWQ W4A16 kernel, and in most cases the results are slower than PyTorch's FP16 Linear. I tested weight sizes (4096, 4096), (5120, 5120), (6656, 6656), and (8192, 8192), which match the linear-layer sizes of the LLaMA model family, on A6000 and RTX 3090 GPUs. I see that the experiments in the paper were run on an A100 GPU. Is there a specific setting or condition needed to see a speedup that matches the results in the paper?
The overhead of activation quantization implemented with plain PyTorch operations is substantial, but even the kernel by itself is slower than nn.Linear in most cases.
@jiwonsong-dev QuantLinear performs online activation quantization with plain PyTorch operations, which is very slow. The GEMM speedup in our paper is measured without activation quantization. If you want to reproduce the speedup, please refer to https://github.com/HandH1998/QQQ/issues/2#issuecomment-2179921604. By the way, in our vLLM PR the activation quantization is fused into element-wise kernels such as RMSNorm, so it does not affect inference speed much.
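To make the moving parts concrete, here is a minimal pure-Python sketch of RMSNorm followed by per-token activation quantization and the integer-accumulate-then-rescale GEMM. It assumes symmetric per-token INT8 activations (scale = amax / 127) and per-output-channel weight scales; the function names, `eps`, and the rounding convention are illustrative assumptions, not QQQ's or vLLM's code. Normalizing and quantizing in the same pass over each row is the idea behind fusing quantization into the RMSNorm kernel: the normalized FP16 activations never have to be written out and re-read by a separate quantization step.

```python
import math

def rmsnorm_quant_per_token(x, weight, eps=1e-6):
    """RMSNorm + symmetric per-token INT8 quantization in one pass (illustrative).

    x: list of token rows (floats); weight: RMSNorm gain per hidden unit.
    Returns (int8-valued rows, one float scale per token).
    """
    q_rows, scales = [], []
    for row in x:
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        normed = [v * inv_rms * g for v, g in zip(row, weight)]
        amax = max(abs(v) for v in normed)
        scale = amax / 127.0 if amax > 0 else 1.0
        # round() is round-half-to-even; clamp to the symmetric INT8 range
        q_rows.append([max(-127, min(127, round(v / scale))) for v in normed])
        scales.append(scale)
    return q_rows, scales

def int8_gemm_dequant(q_a, s_a, q_w, s_w):
    """y[i][j] = s_a[i] * s_w[j] * sum_k q_a[i][k] * q_w[j][k].

    q_w holds one row per output channel (N x K). The exact integer dot
    product stands in for INT32 accumulation; scales are applied once per output.
    """
    return [
        [s_a[i] * s_w[j] * sum(a * w for a, w in zip(q_a[i], q_w[j]))
         for j in range(len(q_w))]
        for i in range(len(q_a))
    ]
```

In a real kernel the integer dot product typically runs on INT8 tensor cores with INT32 accumulation, and the two scales are applied in the epilogue.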
Is the kernel integrated into vLLM the same one as in this repo? I see QuantLinear running slower than nn.Linear for M from 1 to 1024 with N and K fixed to 4096, even when the quantization overhead is not counted.
@jiwonsong-dev The kernel is the same as the one in vLLM. If your modified QuantLinear has no other operations such as dtype conversion or reshape, it should perform about the same as calling the GEMM kernel directly. In general, QuantLinear is only meant for simple inference in our repo. I recommend trying vLLM for practical inference.
I checked your fork of the Marlin repository and saw the actual speedup with its benchmark code. Thank you for the kind response!
Is there a specific reason why the permutation differs when packing per-channel quantized weights? The per-group case follows the original Marlin format.
@jiwonsong-dev It is due to the fragment-layout requirements of the mma instruction for integer types: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type.
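One way to see why the per-channel (INT8 mma) packing needs a different permutation from Marlin's FP16 path: the B-operand fragment of `mma.m16n8k16` assigns different (row, col) elements of the 16x8 tile to each lane depending on the element type. The sketch below encodes my reading of the B-fragment tables in the PTX documentation (`groupID = lane >> 2`, `threadID_in_group = lane % 4`); treat the formulas as something to double-check against the linked docs, not as QQQ's packing code.

```python
def b_frag_f16(lane):
    """(row=k, col=n) of the 16x8 B tile held by `lane` for mma.m16n8k16 .f16.

    Each lane holds 4 halves (b0..b3) spread over two 32-bit registers.
    """
    group_id, tig = lane >> 2, lane % 4
    return [(tig * 2 + (i & 1) + (8 if i >= 2 else 0), group_id) for i in range(4)]

def b_frag_s8(lane):
    """(row=k, col=n) of the 16x8 B tile held by `lane` for mma.m16n8k16 .s8/.u8.

    Each lane holds 4 bytes (b0..b3) packed into one 32-bit register.
    """
    group_id, tig = lane >> 2, lane % 4
    return [(tig * 4 + i, group_id) for i in range(4)]
```

Under these formulas, lane 0 holds k = 0, 1, 8, 9 of column 0 in the FP16 layout but k = 0, 1, 2, 3 in the INT8 layout, so a weight permutation built for one element type would hand values to the wrong lanes for the other.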
@HandH1998 I have tried the QQQ-w4a8-no-group version on internVL-20B for my own task. The awkward thing is that, compared to w8a8, w4a8 is indeed faster at decoding, as expected, but slower at first-token generation. Because of this trade-off between first-token and decoding time, the overall speed of w4a8 ends up slightly slower than w8a8.
What puzzles me is that I am already using the w4a8 per-channel version, with no groups, so why is w4a8 first-token generation still this slow? According to your paper, converting w4 to w8 for the w8a8 GEMM is simply a multiplication by 16, which should not be that slow.
Have you analyzed details like this for your w4a8-no-group kernel? Do you have any further advice on optimizing the kernel?
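The per-channel W4→W8 conversion described above as a multiplication by 16 can be sketched as follows, assuming signed INT4 weights in [-8, 7] packed two per byte in two's complement, low nibble first (the packing order and helper names are assumptions, not the QQQ layout). Placing a nibble in the high four bits of a byte and reading the byte as signed INT8 gives exactly `16 * w`, so the INT8 dot product is 16× the INT4 one and the factor can be folded into the per-channel weight scale (s_w8 = s_w4 / 16). The conversion is only a couple of bit operations per element, but in a real kernel they run on CUDA cores inside the GEMM main loop.

```python
def to_signed8(b):
    """Interpret an unsigned byte (0..255) as a two's-complement int8."""
    return b - 256 if b >= 128 else b

def pack_int4_pair(lo, hi):
    """Pack two signed int4 values (-8..7) into one byte, low nibble first."""
    return (lo & 0xF) | ((hi & 0xF) << 4)

def unpack_to_int8_x16(byte):
    """Expand one packed byte into two int8 values equal to 16 * each int4 value.

    The nibble lands in the high 4 bits with zeros below, so reading the byte
    as signed int8 yields w4 * 16 exactly, with no rounding.
    """
    lo = to_signed8((byte << 4) & 0xFF)
    hi = to_signed8(byte & 0xF0)
    return lo, hi

# Example: the INT8 dot product equals 16x the INT4 dot product.
w4 = [-8, -3, 0, 7]
a8 = [12, -100, 55, 127]
packed = [pack_int4_pair(w4[0], w4[1]), pack_int4_pair(w4[2], w4[3])]
w8 = [v for b in packed for v in unpack_to_int8_x16(b)]
print(w8, sum(a * w for a, w in zip(a8, w8)), 16 * sum(a * w for a, w in zip(a8, w4)))
```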
@brisker It is normal for w4a8 first-token latency to be slower than w8a8: the additional dequantization in w4a8 runs on the slower CUDA cores and slows down the main loop, even though the dequantization overhead is small. In my experiments, if your workload has more than a couple of decoding iterations, the end-to-end w4a8 speed is always faster than w8a8 thanks to its better decoding speed. Here are the detailed results.
Setup: vLLM, llama-2 series, input length = 1024, output length = 128.
TTFT (ms)
| model / bsz | fp16 | sq-w8a8 | awq-g128 | marlin | marlin-g128 | qqq | qqq-g128 |
|---|---|---|---|---|---|---|---|
| 7b | |||||||
| 1 | 71.63 | 46.71 | 104.29 | 71.94 | 78.35 | 53.44 | 65.99 |
| 4 | 270.68 | 171.24 | 293.73 | 285.65 | 313.15 | 208.25 | 261.54 |
| 16 | 274.95 | 175.09 | 299.09 | 290.64 | 318.39 | 212.36 | 266.04 |
| 64 | 294.64 | 198.12 | 318.90 | 315.41 | 343.35 | 238.14 | 290.36 |
| 13b | |||||||
| 1 | 133.48 | 78.37 | 204.17 | 132.94 | 146.03 | 90.93 | 117.30 |
| 4 | 241.39 | 155.24 | 312.20 | 265.77 | 293.79 | 180.91 | 234.62 |
| 16 | 245.51 | 158.43 | 316.91 | 269.22 | 297.37 | 184.73 | 238.52 |
| 64 | 285.86 | 180.74 | 337.22 | 289.07 | 317.10 | 204.47 | 257.47 |
| 70b | |||||||
| 1 | - | 356.51 | 992.45 | 662.55 | 756.93 | 417.54 | 571.21 |
| 4 | - | 1400.66 | 2766.41 | 2627.26 | 3010.08 | 1674.70 | 2292.46 |
| 16 | - | 1402.62 | 2775.86 | 2635.86 | 3016.53 | 1682.51 | 2296.78 |
| 64 | - | 1425.73 | 2807.03 | 2661.51 | 3023.35 | 1712.25 | 2326.96 |
TPOT (ms)
| model / bsz | fp16 | sq-w8a8 | awq-g128 | marlin | marlin-g128 | qqq | qqq-g128 |
|---|---|---|---|---|---|---|---|
| 7b | |||||||
| 1 | 11.70 | 15.81 | 9.32 | 6.33 | 6.47 | 6.34 | 6.59 |
| 4 | 12.78 | 17.33 | 10.51 | 7.39 | 7.45 | 7.47 | 7.74 |
| 16 | 24.62 | 23.34 | 23.64 | 19.59 | 20.40 | 17.86 | 19.37 |
| 64 | 71.74 | 57.75 | 82.38 | 70.03 | 74.03 | 62.28 | 69.45 |
| 13b | |||||||
| 1 | 20.10 | 18.29 | 14.39 | 8.99 | 9.18 | 8.94 | 9.31 |
| 4 | 23.76 | 22.12 | 18.77 | 12.85 | 13.24 | 12.18 | 12.90 |
| 16 | 43.33 | 34.32 | 42.85 | 33.15 | 34.98 | 28.45 | 31.72 |
| 64 | 146.53 | 92.54 | 151.86 | 117.41 | 125.15 | 96.60 | 111.25 |
| 70b | |||||||
| 1 | - | 54.27 | 50.79 | 29.29 | 30.20 | 28.88 | 30.48 |
| 4 | - | 61.20 | 54.16 | 32.06 | 32.84 | 31.22 | 32.98 |
| 16 | - | 160.92 | 135.19 | 104.32 | 114.00 | 80.27 | 96.51 |
| 64 | - | 526.42 | 546.04 | 408.20 | 453.59 | 283.48 | 363.82 |
@HandH1998 What do TTFT (ms) and TPOT (ms) mean in your chart?
- TTFT: Time To First Token
- TPOT: Time Per decoding Output Token
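A minimal sketch of how these two metrics can be computed from one request's output-token timestamps, assuming TPOT averages only the gaps after the first token. The function name and inputs are hypothetical; this is not the vLLM benchmark code.

```python
def ttft_tpot(request_start, token_times):
    """Return (TTFT, TPOT) in the same unit as the timestamps.

    TTFT: time from request arrival to the first output token (covers prefill).
    TPOT: mean gap between consecutive output tokens after the first one;
    None when only one token was produced (no decode gaps to average).
    """
    if not token_times:
        raise ValueError("need at least one output token")
    ttft = token_times[0] - request_start
    if len(token_times) == 1:
        return ttft, None
    tpot = (token_times[-1] - token_times[0]) / (len(token_times) - 1)
    return ttft, tpot
```

With an output length of 128, each request contributes one TTFT sample and 127 decode gaps to TPOT under this definition.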
@HandH1998
For sq-w8a8 in your chart, which specific kernel are you referring to?
In my experiments, I used the official w8a8 kernel from vLLM (CUTLASS backend).
The cuBLAS w8a8 GEMM from https://github.com/vllm-project/vllm/pull/1508. cuBLAS and CUTLASS should have similar performance, though.
TPOT: Time Per decoding Output Token
Does TPOT already include the first decoding step, or have you excluded the first-token time?
It doesn't include the first token.
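Since TPOT excludes the first token, a rough per-request latency is TTFT + (output_len - 1) × TPOT. The sketch below applies that to the llama-2-7b, bsz=1 numbers from the tables above for sq-w8a8 and qqq, and computes how many decode steps QQQ's faster decoding needs to recover its slower prefill. The additive model and the helper names are assumptions for illustration; it ignores scheduling effects and step-to-step variance.

```python
def e2e_latency_ms(ttft_ms, tpot_ms, output_len):
    """Approximate request latency: first token plus the remaining decode steps."""
    return ttft_ms + (output_len - 1) * tpot_ms

def breakeven_decode_steps(ttft_a, tpot_a, ttft_b, tpot_b):
    """Decode steps after which A (slower prefill, faster decode) catches up with B."""
    if tpot_a >= tpot_b:
        return float("inf")  # A never catches up
    return (ttft_a - ttft_b) / (tpot_b - tpot_a)

# llama-2-7b, bsz=1, input 1024 / output 128 (values copied from the tables above)
w8a8 = {"ttft": 46.71, "tpot": 15.81}
qqq = {"ttft": 53.44, "tpot": 6.34}

print(e2e_latency_ms(w8a8["ttft"], w8a8["tpot"], 128))
print(e2e_latency_ms(qqq["ttft"], qqq["tpot"], 128))
print(breakeven_decode_steps(qqq["ttft"], qqq["tpot"], w8a8["ttft"], w8a8["tpot"]))
```

For this row the estimate is roughly 2055 ms for sq-w8a8 versus 859 ms for qqq, with the break-even below one decode step. Other models, batch sizes, or kernels will shift the ratio, so plugging in your own measured TTFT and TPOT is the real test.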
Regarding this table:
- Which GPU are you using?
- For TTFT (Time To First Token), why is llama2-13b even faster than llama2-7b?
- For TTFT, is the time for the whole batch or for a single sample? For example, with 245.51 ms for batch size 16, does this mean the first token takes about 0.25 × 16 ≈ 3.9 seconds for llama2-13b (batch size 16)?
@HandH1998
@brisker
- A100-80G.
- I think it is because the matrix multiplications with llama2-13b's shapes run more efficiently on the GPU than those of llama2-7b.
- No, it means that 245.51 ms covers all the first tokens of bsz=16.
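If TTFT is measured until all 16 first tokens are out, as stated above, it reads as the latency of one batched prefill rather than a per-sample time. A tiny sketch converting it into prefill throughput (prompt tokens processed per second), using the llama2-13b fp16, bsz=16 entry; the helper name is hypothetical:

```python
def prefill_tokens_per_s(bsz, input_len, ttft_ms):
    """Prompt tokens processed per second if one TTFT covers the whole batch's prefill."""
    return bsz * input_len / (ttft_ms / 1000.0)

# llama2-13b, fp16, bsz=16, input length 1024 (from the TTFT table)
print(round(prefill_tokens_per_s(16, 1024, 245.51)))
```

This comes to roughly 66.7k prompt tokens per second; multiplying TTFT by the batch size would instead treat the 16 prefills as if they ran one after another.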
@HandH1998 I am confused about why w4a8 is faster than w8a8 on the 70B model. It does not seem to match the theoretical roofline model (the figure in Qserve...). I think at bs=64 it still falls in the memory-bound regime, while larger batches will soon run out of memory (OOM).
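A back-of-the-envelope roofline check for the decode-phase linear layers, under loudly stated assumptions: weight bytes dominate memory traffic, activation/output bytes and dequantization cost are ignored, and the A100 figures (624 dense INT8 TOPS, about 2.039 TB/s for the 80 GB SXM part) are nominal datasheet numbers, not measurements. For an (M×K)·(K×N) GEMM the intensity is then about 2MKN / (KN·bits/8) = 16M/bits ops per byte, so halving the weight bits doubles the intensity at a given batch size M. Helper names are hypothetical.

```python
def gemm_intensity(m, weight_bits):
    """Ops per byte of an (m x k) @ (k x n) GEMM counting only weight bytes.

    2*m*k*n ops / (k*n*weight_bits/8) bytes = 16*m/weight_bits; k and n cancel.
    """
    return 16.0 * m / weight_bits

def regime(m, weight_bits, peak_tops=624.0, bw_tbps=2.039):
    """'memory-bound' if intensity is below the ridge point peak / bandwidth."""
    ridge = peak_tops / bw_tbps  # ops per byte (both inputs in units of 1e12 per second)
    return "memory-bound" if gemm_intensity(m, weight_bits) < ridge else "compute-bound"

def ridge_batch(weight_bits, peak_tops=624.0, bw_tbps=2.039):
    """Batch size M at which the GEMM crosses from memory- to compute-bound."""
    return (peak_tops / bw_tbps) * weight_bits / 16.0

for bits in (8, 4):
    print(bits, round(ridge_batch(bits), 1), regime(64, bits))
```

With these numbers the ridge sits near M ≈ 153 for 8-bit weights and M ≈ 76 for 4-bit weights, so both W8A8 and W4A8 are memory-bound at M = 64, and in this simplified model W4A8 should still win on decode because it reads half the weight bytes. The model leaves out the dequantization work on CUDA cores and the KV-cache traffic of attention, which does not shrink with weight bits.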