
Condition to achieve linear speedup?

Open jiwonsong-dev opened this issue 1 year ago • 18 comments

I tested the latency of the QuantLinear forward pass with various input and feature sizes. For token counts from 1 to 1024, I see no speedup over the AWQ W4A16 kernel, and the results are worse than the PyTorch FP16 Linear in most cases. I tested weight sizes (4096, 4096), (5120, 5120), (6656, 6656), and (8192, 8192), which match the linear-layer sizes of the LLaMA model family, on A6000 and RTX 3090 GPUs. I see the experiments in the paper were run on an A100 GPU. Is there any specific setting or condition required to reproduce the speedups reported in the paper?
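For reference, a minimal sketch of the kind of timing I ran (the QuantLinear import and constructor are placeholders, since the exact API in this repo may differ):

```python
import torch
import torch.nn as nn

# Placeholder import -- adjust to wherever QuantLinear lives in this repo.
# from QQQ.qlinear import QuantLinear

def bench_ms(layer, x, iters=100):
    # Warm up, then time with CUDA events to avoid host-side noise.
    for _ in range(10):
        layer(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        layer(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per forward

N = K = 4096  # also tried 5120, 6656, 8192
fp16_linear = nn.Linear(K, N, bias=False, device="cuda", dtype=torch.float16)
for M in (1, 16, 64, 256, 1024):
    x = torch.randn(M, K, device="cuda", dtype=torch.float16)
    print(f"M={M:5d}  fp16 nn.Linear: {bench_ms(fp16_linear, x):.3f} ms")
    # print(f"M={M:5d}  QuantLinear:   {bench_ms(quant_linear, x):.3f} ms")
```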

jiwonsong-dev avatar Sep 12 '24 12:09 jiwonsong-dev

The overhead of the activation quantization done with plain PyTorch operations is substantial, but even aside from that, the kernel itself is slower than nn.Linear in most cases.

jiwonsong-dev avatar Sep 13 '24 02:09 jiwonsong-dev

@jiwonsong-dev QuantLinear performs online activation quantization with plain PyTorch operations, which is very slow. The GEMM speedup in our paper is evaluated without activation quantization. If you want to reproduce the speedup, please refer to https://github.com/HandH1998/QQQ/issues/2#issuecomment-2179921604. By the way, in our vLLM PR the activation quantization is fused into element-wise kernels such as RMSNorm, so it does not affect inference speed much.
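For context, the online activation quantization in QuantLinear is roughly of the following form (a simplified sketch, not the exact code in this repo); each step is at least one extra kernel launch on top of the GEMM, which is what dominates at small batch sizes:

```python
import torch

def quantize_activation_per_token(x: torch.Tensor):
    # Symmetric per-token int8 quantization done with plain PyTorch ops.
    # Every line here launches at least one separate CUDA kernel, so doing
    # this unfused in front of the GEMM is expensive for small M.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)
x_q, scale = quantize_activation_per_token(x)
```

In the vLLM PR this work is folded into the preceding element-wise kernel, so these extra launches disappear.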

HandH1998 avatar Sep 13 '24 02:09 HandH1998

Is the kernel integrated into vLLM the same as the one in this repo? I see QuantLinear being slower than nn.Linear for M from 1 to 1024 with N and K fixed to 4096, even when the quantization overhead is excluded.

jiwonsong-dev avatar Sep 13 '24 06:09 jiwonsong-dev

@jiwonsong-dev The kernel is the same as the one in vLLM. If there are no other operations like dtype conversion and reshape in your modified QuantLinear, it should deliver performance similar to calling the GEMM kernel directly. In general, QuantLinear is only meant for simple inference in our repo; I recommend trying vLLM for practical inference.
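If it helps, running a QQQ checkpoint through vLLM looks roughly like the sketch below (the model path is a placeholder, and it assumes the QQQ backend from our vLLM PR is available in your build under the `qqq` quantization name):

```python
from vllm import LLM, SamplingParams

# "path/to/qqq-quantized-model" is a placeholder; quantization="qqq" assumes
# the QQQ kernels from the vLLM PR are present in your vLLM build.
llm = LLM(model="path/to/qqq-quantized-model", quantization="qqq")
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```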

HandH1998 avatar Sep 13 '24 06:09 HandH1998

I checked your fork of the Marlin repository and saw the actual speedup with its benchmark code. Thank you for the kind response!

jiwonsong-dev avatar Sep 14 '24 04:09 jiwonsong-dev

Is there any specific reason why the permutation is different when packing per-channel quantized weights? The per-group path follows the original Marlin format.

jiwonsong-dev avatar Sep 19 '24 12:09 jiwonsong-dev

@jiwonsong-dev It is related to the mma instruction's requirements: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type.

HandH1998 avatar Sep 24 '24 02:09 HandH1998

@HandH1998 I tried the QQQ w4a8 no-group version on InternVL-20B for my own task. The awkward result is that, compared to w8a8, w4a8 is indeed faster at decoding as expected, but slower at first-token generation. Because of this trade-off between the first token and decoding, the overall w4a8 speed ends up slightly slower than w8a8.

What puzzles me is: I am already using the w4a8 per-channel version with no groups, so why is the w4a8 first token still so slow? According to your paper, dequantizing w4 to w8 for the w8a8 GEMM is simply a multiplication by 16, which should not be that slow.
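For concreteness, my understanding of that step is something like this sketch (the scale value is made up):

```python
import numpy as np

# A signed 4-bit weight in [-8, 7] is promoted to int8 by multiplying by 16
# (a left shift by 4), and the per-channel fp scale absorbs the extra 1/16,
# so the dequant itself should be a single cheap integer multiply.
w4 = np.array([-8, -3, 0, 5, 7], dtype=np.int8)
w8 = (w4.astype(np.int16) * 16).astype(np.int8)   # stays within int8: [-128, 112]
scale = 0.01                                       # hypothetical per-channel scale
print(w8)                  # -> [-128 -48 0 80 112]
print(w8 * (scale / 16))   # same values as w4 * scale
```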

Have you ever analyzed details like this for your w4a8-no-group kernel? Any further advice to optimize the kernel?

brisker avatar Sep 29 '24 07:09 brisker

@brisker It is normal that the w4a8 first token is slower than w8a8: the additional dequant operation of w4a8 (on the slower CUDA cores) slows down the main loop, even though the dequant overhead is small. In my experiments, as long as your workload has more than a couple of decoding iterations, the end-to-end w4a8 speed is always faster than w8a8 thanks to the better decoding speed. Here are the detailed results.

Setup: input length = 1024, output length = 128, vLLM, LLaMA-2 series.

TTFT (ms)

| model | bsz | fp16 | sq-w8a8 | awq-g128 | marlin | marlin-g128 | qqq | qqq-g128 |
|---|---|---|---|---|---|---|---|---|
| 7b | 1 | 71.63 | 46.71 | 104.29 | 71.94 | 78.35 | 53.44 | 65.99 |
| 7b | 4 | 270.68 | 171.24 | 293.73 | 285.65 | 313.15 | 208.25 | 261.54 |
| 7b | 16 | 274.95 | 175.09 | 299.09 | 290.64 | 318.39 | 212.36 | 266.04 |
| 7b | 64 | 294.64 | 198.12 | 318.90 | 315.41 | 343.35 | 238.14 | 290.36 |
| 13b | 1 | 133.48 | 78.37 | 204.17 | 132.94 | 146.03 | 90.93 | 117.30 |
| 13b | 4 | 241.39 | 155.24 | 312.20 | 265.77 | 293.79 | 180.91 | 234.62 |
| 13b | 16 | 245.51 | 158.43 | 316.91 | 269.22 | 297.37 | 184.73 | 238.52 |
| 13b | 64 | 285.86 | 180.74 | 337.22 | 289.07 | 317.10 | 204.47 | 257.47 |
| 70b | 1 | - | 356.51 | 992.45 | 662.55 | 756.93 | 417.54 | 571.21 |
| 70b | 4 | - | 1400.66 | 2766.41 | 2627.26 | 3010.08 | 1674.70 | 2292.46 |
| 70b | 16 | - | 1402.62 | 2775.86 | 2635.86 | 3016.53 | 1682.51 | 2296.78 |
| 70b | 64 | - | 1425.73 | 2807.03 | 2661.51 | 3023.35 | 1712.25 | 2326.96 |

TPOT (ms)

| model | bsz | fp16 | sq-w8a8 | awq-g128 | marlin | marlin-g128 | qqq | qqq-g128 |
|---|---|---|---|---|---|---|---|---|
| 7b | 1 | 11.70 | 15.81 | 9.32 | 6.33 | 6.47 | 6.34 | 6.59 |
| 7b | 4 | 12.78 | 17.33 | 10.51 | 7.39 | 7.45 | 7.47 | 7.74 |
| 7b | 16 | 24.62 | 23.34 | 23.64 | 19.59 | 20.40 | 17.86 | 19.37 |
| 7b | 64 | 71.74 | 57.75 | 82.38 | 70.03 | 74.03 | 62.28 | 69.45 |
| 13b | 1 | 20.10 | 18.29 | 14.39 | 8.99 | 9.18 | 8.94 | 9.31 |
| 13b | 4 | 23.76 | 22.12 | 18.77 | 12.85 | 13.24 | 12.18 | 12.90 |
| 13b | 16 | 43.33 | 34.32 | 42.85 | 33.15 | 34.98 | 28.45 | 31.72 |
| 13b | 64 | 146.53 | 92.54 | 151.86 | 117.41 | 125.15 | 96.60 | 111.25 |
| 70b | 1 | - | 54.27 | 50.79 | 29.29 | 30.20 | 28.88 | 30.48 |
| 70b | 4 | - | 61.20 | 54.16 | 32.06 | 32.84 | 31.22 | 32.98 |
| 70b | 16 | - | 160.92 | 135.19 | 104.32 | 114.00 | 80.27 | 96.51 |
| 70b | 64 | - | 526.42 | 546.04 | 408.20 | 453.59 | 283.48 | 363.82 |

HandH1998 avatar Sep 29 '24 08:09 HandH1998

@HandH1998 What do TTFT (ms) and TPOT (ms) actually mean in your chart?

brisker avatar Sep 29 '24 08:09 brisker

TTFT: Time To First Token
TPOT: Time Per decoding Output Token

HandH1998 avatar Sep 29 '24 08:09 HandH1998

@HandH1998 For sq-w8a8 in your chart, which specific kernel are you referring to? In my experiments, I used the official w8a8 kernel from vLLM (CUTLASS backend).

brisker avatar Sep 29 '24 08:09 brisker

The cuBLAS w8a8 GEMM from https://github.com/vllm-project/vllm/pull/1508. But cuBLAS and CUTLASS should have similar performance.

HandH1998 avatar Sep 29 '24 08:09 HandH1998

TPOT: Time Per decoding Output Token

Does TPOT already include the first decoding step, or have you excluded the first-token time?

brisker avatar Sep 29 '24 08:09 brisker

It doesn't include the first token.
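Concretely, given the completion timestamps of each output token, the two metrics can be computed roughly like this (a sketch with hypothetical variable names):

```python
# t_request: wall-clock time when the request was issued.
# t_tokens[i]: wall-clock time when the i-th output token finished.
def ttft_ms(t_request: float, t_tokens: list[float]) -> float:
    # Time To First Token: prefill plus the first decoded token.
    return (t_tokens[0] - t_request) * 1000.0

def tpot_ms(t_tokens: list[float]) -> float:
    # Time Per Output Token: average gap between consecutive decoded tokens;
    # the first token is excluded.
    gaps = [b - a for a, b in zip(t_tokens[:-1], t_tokens[1:])]
    return sum(gaps) / len(gaps) * 1000.0
```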

HandH1998 avatar Sep 29 '24 08:09 HandH1998

Considering the table you posted above:

  1. Which GPU are you using?
  2. For TTFT (Time To First Token), why is llama2-13b even faster than llama2-7b?
  3. For TTFT, is the time for the whole batch or for a single sample? For example, for llama2-13b at batch size 16, does 245.51 ms mean the first tokens take about 0.2*16=3.2 seconds in total?

@HandH1998

brisker avatar Oct 09 '24 06:10 brisker

@brisker

  1. A100-80G.
  2. I think it is because the matrix multiplications of these shapes in llama2-13b achieve greater acceleration than those in llama2-7b.
  3. No, it means that 245.51 ms covers all the first tokens of the bsz=16 batch.

HandH1998 avatar Oct 11 '24 03:10 HandH1998

@HandH1998 I am just confused about why w4a8 is faster than w8a8 on the 70B model. That does not seem to match the theoretical roofline model (the figure in QServe). I think at bsz=64 it still falls into the memory-bound regime, and meanwhile it would OOM soon.

Andy0422 avatar Oct 31 '24 07:10 Andy0422