
Support int8 KVCacheQuant and W8A8 inference in vllm

Open AniZpZ opened this issue 1 year ago • 69 comments

We have recently implemented and tested int8 KV-Cache quantization and W8A8 inference in vLLM. We found that our quantization implementation can increase throughput by over 20% and reduce first-token latency under heavy load. In contrast, the W4A16 quantization methods available in vLLM (e.g. the AWQ-based method) cannot improve throughput, as shown in PR #1032, because they cannot benefit from the int8 tensor cores. So we propose this PR as an alternative quantization method.

Updates!!! We have made some more progress in https://github.com/vllm-project/vllm/pull/1112#issuecomment-1770889034

More Updates!!! If you want to properly evaluate the MMLU dataset with vLLM, some modifications to the sampler are required. The code can be found in our mmlu_eval branch.

Important message!!! We split the PR into two parts for easier review and use. The W8A8 inference part is in https://github.com/vllm-project/vllm/pull/1508 and the KV cache quant part is in https://github.com/vllm-project/vllm/pull/1507

What we have right now:

  1. int8 KV-Cache quantization related work:
     a. Quant/Dequant helper functions adapted from FasterTransformer
     b. Quantized versions of the CUDA kernels
     c. Unit tests for the added kernels
  2. W8A8 inference related work:
     a. Int8 GEMM kernels adapted from torch-int
     b. W8A8 linear layer modules
     c. W8A8 inference support for the LLaMA model
  3. Test results based on our own dataset

What we plan to do:

  • [x] 1. Further kernel fusion
  • [x] 2. Code refactoring and cleaning
  • [x] 3. Optimize int8 GEMM kernel
  • [x] 4. Release SmoothQuant for LLaMA
  • [x] 5. Add code for generating KV-Cache quantization parameters (scales and zero points)
  • [x] 6. Experiments on more datasets

How to test throughput

A. How to enable W8A8 inference

~~0. Install CUTLASS, because we currently use the CUTLASS GEMM kernel. We plan to replace it with a cuBLAS GEMM kernel soon.~~ We now support the cuBLAS GEMM kernel, so you can remove the CUTLASS GEMM kernel in setup.py.

  1. Install smoothquant and torch-int for LLaMA. Use "examples/generate_act_scales.py" to generate the activation scales, then use "examples/export_int8_llama.py" to export the int8 model. Please remember to check and change the 'architectures' field in the model's config.json from 'Int8LlamaForCausalLM' to 'LlamaForCausalLM' (see the sketch after the command below).
  2. Update vLLM and execute:
python ./vllm/entrypoints/api_server.py --model=/path/to/quantized/model --tokenizer=/path/to/tokenizer --max-num-batched-tokens=70000 --block-size=16 --swap-space=20 --quantization smoothquant
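
For the 'architectures' fix in step 1, a minimal sketch that patches the exported model's config.json (the path below is a placeholder):

import json

config_path = "/path/to/quantized/model/config.json"

with open(config_path) as f:
    config = json.load(f)

# Replace the exporter's class name with the one vLLM expects.
config["architectures"] = [
    "LlamaForCausalLM" if arch == "Int8LlamaForCausalLM" else arch
    for arch in config.get("architectures", [])
]

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)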

B. How to enable KV cache quant

  1. Use vllm/kv_quant/calibrate.py to generate the scales, then use vllm/kv_quant/export_kv_params.py to export the KV cache quantization parameters.
  2. Execute:
python ./vllm/entrypoints/api_server.py --model=/path/to/quantized/model --tokenizer=/path/to/tokenizer --max-num-batched-tokens=70000 --block-size=16 --swap-space=20 --kv-cache-dtype=int8 --kv-quant-params-path=/path/to/kv_params_dir

KV cache quantization and W8A8 inference can also be enabled together; a quick way to sanity-check the running server is sketched below.
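
Once either server is up, a quick smoke test (a sketch that assumes the demo api_server's /generate endpoint on the default port 8000, which forwards extra JSON fields to SamplingParams):

import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0.0,
    },
)
resp.raise_for_status()
# Expected shape at the time of writing: {"text": ["<prompt + completion>", ...]}
print(resp.json())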

Experiment Result Current test results on our own dataset on an A100 80G (updated with quant/RMSNorm fusion and the GEMM D2H bug fix)

Throughput of FP16 LLaMA-13B:

Throughput:  4.9945 requests/s, 543.0660 token/s
Average latency: 31.7689 s

Throughput of Int8 LLaMA-13B with int8 KVCacheQuant:

Throughput: 6.1147 requests/s, 664.8646 token/s, 
Average latency: 27.4222 s

Throughput of Int8 LLaMA-13B with int8 KVCacheQuant, using cublas gemm kernel:

Throughput: 6.4723 requests/s, 703.7514 token/s, 
Average latency: 25.9912 s

How to evaluate model performance We added an evaluation method for quantized models; currently the MMLU dataset is supported. You can find the details in benchmarks/benchmark_evaluation.py

python benchmark_evaluation.py --model=/path/to/quantized/model --tokenizer=/path/to/tokenizer --dev-data-path=/path/to/mmlu/dev/ --test-data-path=/path/to/mmlu/test/ --kv-cache-dtype=int8 --kv-quant-params-path=/path/to/kv_params_dir --quantization=smoothquant

Updates We have released SmoothQuant for LLaMA in https://github.com/AniZpZ/smoothquant/tree/llama-dev and https://github.com/AniZpZ/torch-int/tree/llama-dev

The code for generating KV-Cache quantization parameters is ready; check the vllm/kv_quant folder.

We replaced the int8 GEMM with the cuBLAS version, and the throughput increase now comes to around 30%.

AniZpZ avatar Sep 20 '23 13:09 AniZpZ

This is interesting work! I was going to implement int8 in AutoAWQ with time as the authors of SmoothQuant (this PR) and AWQ are the same. My best guestimate is that single_query_cached_kv_attention_quantized_kernel is doing the heavy lifting of throughput here as it comes from FasterTransformer which is well optimized.

casper-hansen avatar Sep 20 '23 13:09 casper-hansen

I fully support this, since the 4-bit AWQ model proved to have inferior quality for my use cases. Having 8-bit weights with an 8-bit activation cache would be the best of both worlds, allowing for almost no loss of quality (perplexity) while running inference more efficiently. I would also keep a W8A16 mode as an option, should the precision of the activations and the KV cache make a difference in specific use cases.

viktor-ferenczi avatar Sep 21 '23 06:09 viktor-ferenczi

Hi vLLM genius @WoosukKwon @zhuohan123

This is the latest development from our team regarding quantization support for vLLM; we had done something similar to https://github.com/vllm-project/vllm/pull/1032 before. At that time we did not open a PR because our benchmark results showed a drop in throughput, but we later found that https://github.com/vllm-project/vllm/pull/1032 was merged, which is very encouraging. We have therefore continued performance optimization on that basis and are sending out this PR in WIP state early, hoping to get comments and suggestions so that it can eventually be merged into the vLLM codebase smoothly. Cheers!

zhyncs avatar Sep 21 '23 08:09 zhyncs

@AniZpZ @zhyncs This is great work! My understanding is that this PR converts FP16 -> INT8 dynamically without computing a loss function to optimize perplexity. Have you evaluated perplexity on this approach?

casper-hansen avatar Sep 21 '23 09:09 casper-hansen

@AniZpZ @zhyncs This is great work! My understanding is that this PR converts FP16 -> INT8 dynamically without computing a loss function to optimize perplexity. Have you evaluated perplexity on this approach?

We implement quantization with the SmoothQuant method for W8A8; I will release the code later. The perplexity is identical to a standard SmoothQuant model if you run W8A8 inference without int8 KVCacheQuant.

Quantization details are discussed in this paper (Xiao et al.)
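
For readers unfamiliar with the method, a minimal sketch of the smoothing step from the paper (illustrative only, not the code in this PR; act_absmax is assumed to come from a calibration pass such as generate_act_scales.py):

import torch

def smooth_linear(weight: torch.Tensor, act_absmax: torch.Tensor, alpha: float = 0.5):
    """weight: [out_features, in_features]; act_absmax: per-input-channel max |activation|, [in_features]."""
    w_absmax = weight.abs().amax(dim=0)  # per input channel, [in_features]
    scale = (act_absmax.clamp(min=1e-5) ** alpha) / (w_absmax.clamp(min=1e-5) ** (1.0 - alpha))
    smoothed_weight = weight * scale     # W' = W * diag(s), broadcast over in_features
    # At runtime the activation is divided by the same scale (X' = X / diag(s));
    # in practice this division is folded into the preceding LayerNorm/RMSNorm weights.
    return smoothed_weight, scale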

AniZpZ avatar Sep 21 '23 09:09 AniZpZ

@AniZpZ @zhyncs This is great work! My understanding is that this PR converts FP16 -> INT8 dynamically without computing a loss function to optimize perplexity. Have you evaluated perplexity on this approach?

We implement quantization with the SmoothQuant method for W8A8; I will release the code later. The perplexity is identical to a standard SmoothQuant model if you run W8A8 inference without int8 KVCacheQuant.

Quantization details are discussed in this paper (Xiao et al.)

SmoothQuant only supports OPT models. How can we test this PR when the SmoothQuant repository does not support LLaMa models? If you implement this PR without the quantization code, you will inevitably end up with a bad perplexity if you naively use W8A8 as you have no calibration dataset.

See this image, accuracy ends up being worse than INT4 if you naively convert weights to W8A8. You need the SmoothQuant or AWQ method to convert if you want to preserve accuracy. You need a framework for this, which is why I created AutoAWQ - I will look to implement INT8 quantization using the torch-int modules and would love your help with this so we can support all models in vLLM (LLaMa, MPT, Falcon, etc.) without accuracy degradation.

image

casper-hansen avatar Sep 21 '23 09:09 casper-hansen

@AniZpZ @zhyncs This is great work! My understanding is that this PR converts FP16 -> INT8 dynamically without computing a loss function to optimize perplexity. Have you evaluated perplexity on this approach?

We implement quantization with the SmoothQuant method for W8A8; I will release the code later. The perplexity is identical to a standard SmoothQuant model if you run W8A8 inference without int8 KVCacheQuant. Quantization details are discussed in this paper (Xiao et al.)

SmoothQuant only supports OPT models. How can we test this PR when the SmoothQuant repository does not support LLaMa models? If you implement this PR without the quantization code, you will inevitably end up with a bad perplexity if you naively use W8A8 as you have no calibration dataset.

See this image, accuracy ends up being worse than INT4 if you naively convert weights to W8A8. You need the SmoothQuant or AWQ method to convert if you want to preserve accuracy. You need a framework for this, which is why I created AutoAWQ - I will look to implement INT8 quantization using the torch-int modules and would love your help with this so we can support all models in vLLM (LLaMa, MPT, Falcon, etc.) without accuracy degradation.

image

We implemented SmoothQuant for LLaMA ourselves; you can find the code here: https://github.com/AniZpZ/smoothquant/tree/llama-dev and easily quantize and export the model with export_int8_llama.py. It should work with https://github.com/AniZpZ/torch-int/tree/llama-dev

AniZpZ avatar Sep 21 '23 10:09 AniZpZ

Hi @AniZpZ @zhyncs, thank you for your great work with this PR.

I have now had more time to explore your fast implementation and found that Nvidia only has support for INT8 for high throughput, which makes this PR achieve higher throughput than INT4 due to software capabilities.

Is your proposal to run W8A16? Your code does not have A8 implemented in the llama.py model definition.

SmoothQuant implements W8A8, but it seems silly to run A8 as there should be little benefit speed-wise. Therefore, I see this as a natural choice. I want to confirm this with you for my implementation in AutoAWQ as I want to push INT8 models out using your initial LLaMa implementation, just using the AWQ method for minimum perplexity loss.

casper-hansen avatar Sep 22 '23 20:09 casper-hansen

Hi @AniZpZ @zhyncs, thank you for your great work with this PR.

I have now had more time to explore your fast implementation and found that Nvidia only has support for INT8 for high throughput, which makes this PR achieve higher throughput than INT4 due to software capabilities.

Is your proposal to run W8A16? Your code does not have A8 implemented in the llama.py model definition.

SmoothQuant implements W8A8, but it seems silly to run A8 as there should be little benefit speed-wise. Therefore, I see this as a natural choice. I want to confirm this with you for my implementation in AutoAWQ as I want to push INT8 models out using your initial LLaMa implementation, just using the AWQ method for minimum perplexity loss.

Our proposal is to run in W8A8. If you enable SmoothQuant, we replace RMSNorm and the linear layers with our custom int8 RMSNorm and W8A8 linears, which quantize the activations and implement int8 GEMM. You can find the details in w8a8linear.py. If you want to use the tensor cores for int8 computation, both the weights and the activations must be int8.
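
To make this concrete, a conceptual sketch of what such a W8A8 linear does; the int8 GEMM is emulated with a float matmul here, the real kernels use CUTLASS/cuBLAS int8 GEMM, and the class name is illustrative:

import torch

class W8A8LinearSketch(torch.nn.Module):
    def __init__(self, int8_weight: torch.Tensor, weight_scale: float, act_scale: float):
        super().__init__()
        self.register_buffer("int8_weight", int8_weight)  # [out, in], dtype torch.int8
        self.weight_scale = weight_scale                  # per-tensor weight scale
        self.act_scale = act_scale                        # static per-tensor activation scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize the fp16 activation to int8 with the calibrated static scale.
        x_q = torch.clamp(torch.round(x / self.act_scale), -128, 127).to(torch.int8)
        # int8 GEMM accumulating in int32 (emulated with a float matmul for portability).
        acc = x_q.float() @ self.int8_weight.float().t()
        # Dequantize the accumulator back to fp16.
        return (acc * (self.act_scale * self.weight_scale)).half()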

AniZpZ avatar Sep 23 '23 02:09 AniZpZ

I agree that the RMSNorm and Linear layers run in INT8, making it W8. Running in A8 means running your activation functions with INT8, and this is not implemented (see below). Currently, this makes it W8A16. EDIT: The modules from torch-int are named W8A8, but they do not actually run A8 for you - they run W8 on your Linear layers.

Your MLP forward runs .half(), making the precision FP16 before running the SiLU activation function.

Please see my annotated comments:

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x) # <-- INT8 inference (W8), out FP32
        gate_up = gate_up.half() # <-- FP32 to FP16 conversion
        x = self.act_fn(gate_up) # <-- FP16 Silu activation function (A16)
        x, _ = self.down_proj(x)
        x = x.half()
        return x

Another note: INT8 inference is not implemented in PagedAttention. Instead, dequantization is implemented. This would likely cause a large speedup (maybe another 20-30% bump in throughput) if implemented.

casper-hansen avatar Sep 23 '23 09:09 casper-hansen

Is it possible to keep the W8A16 mode as well? (Making it optional.)

viktor-ferenczi avatar Sep 23 '23 10:09 viktor-ferenczi

@viktor-ferenczi I am much in favor of W8A16 which is currently implemented. Quantized models will be easier to create in the W8A16 format without accuracy degradation.


@AniZpZ @zhyncs I believe throughput can be increased more if you implement PagedAttention as INT8. Do you have any idea if that is possible because it could increase throughput much more?

This loop dequantizes from INT8, but you could instead apply the INT8 BMM from torch-int if adapted to work with PagedAttention. @WoosukKwon perhaps you are better qualified to answer this. https://github.com/AniZpZ/vllm/blob/kv-quant-merge/csrc/attention/attention_kernels.cu#L432

casper-hansen avatar Sep 23 '23 13:09 casper-hansen


It's possible to run in W8A16, but you cannot benefit from the int8 tensor cores in that case.

AniZpZ avatar Sep 23 '23 13:09 AniZpZ

PagedAttention

We replace the input and post-attention layer norms with I8RMSNorm and make the layernorm output int8, so the QKV GEMM and gate/up GEMM run in W8A8. We use W8A8BFP32OFP32LinearWithSFactor for o_proj and down_proj; W8A8BFP32OFP32LinearWithSFactor quantizes its input activations to int8 and performs the W8A8 GEMM as well.
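
For illustration, a rough sketch of an RMSNorm that emits int8 output for the following W8A8 GEMM (the real implementation is a fused CUDA kernel; the name and static-scale handling here are assumptions):

import torch

def i8_rmsnorm(x: torch.Tensor, gamma: torch.Tensor, out_scale: float, eps: float = 1e-6):
    # x: [tokens, hidden] fp16; gamma: [hidden] RMSNorm weight.
    x32 = x.float()
    rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = x32 * rms * gamma.float()
    # Quantize to int8 with the calibrated static input scale of the next layer.
    return torch.clamp(torch.round(normed / out_scale), -128, 127).to(torch.int8)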

AniZpZ avatar Sep 23 '23 13:09 AniZpZ

@viktor-ferenczi I am much in favor of W8A16 which is currently implemented. Quantized models will be easier to create in the W8A16 format without accuracy degradation.

@AniZpZ @zhyncs I believe throughput can be increased more if you implement PagedAttention as INT8. Do you have any idea if that is possible because it could increase throughput much more?

This loop dequantizes from INT8, but you could instead apply the INT8 BMM from torch-int if adapted to work with PagedAttention. @WoosukKwon perhaps you are better qualified to answer this. https://github.com/AniZpZ/vllm/blob/kv-quant-merge/csrc/attention/attention_kernels.cu#L432

You're correct that implementing PagedAttention in INT8 could potentially improve performance. However, our profiling of the inference process shows that the GEMM computation in PagedAttention only accounts for approximately 5% of the total, which suggests that the performance gain from this particular optimization might be limited. Our current focus is on kernel fusion and optimizing the INT8 GEMM kernel. After finishing the work mentioned above, we will try to implement PagedAttention in INT8.

AniZpZ avatar Sep 23 '23 13:09 AniZpZ

It's possible to run in W8A16, but you cannot benefit from the int8 tensor cores in that case.

We have a misunderstanding of what W8 and A8 mean here.

  • torch-int W8A8: Inputs = INT8/A8, Weights = INT8/W8 in W8A8BFP32OFP32LinearWithSFactor
  • In a normal sense, I would call this W8A16 because activation functions are running in FP16.

PagedAttention only accounts for approximately 5% of the total

Thank you for sharing those numbers with me. Your work in this PR is amazing, the performance is unmatched! I am looking forward to seeing how this develops and to support W8 quantization with AutoAWQ.

casper-hansen avatar Sep 23 '23 13:09 casper-hansen

We replaced the CUTLASS GEMM kernels adapted from torch-int with cuBLAS GEMM kernels and achieved a 6% improvement in throughput.

AniZpZ avatar Sep 27 '23 09:09 AniZpZ

I have implemented SmoothQuant quantization of input activations in AutoAWQ/smoothquant. However, I am getting an error around how the model is loaded.

This is the first layer when I print the state_dict.keys() in vLLM: model.layers.0.self_attn.qkv_proj.weight. However, the llama model definition looks a bit different in general. @AniZpZ can we sync up our versions to make this run smoothly? Ideally we can quantize models using AutoAWQ and we can have TheBloke help push some INT8 models.

EDIT: Works with new PR https://github.com/AniZpZ/vllm/pull/1

casper-hansen avatar Sep 27 '23 16:09 casper-hansen

state_dict

vLLM uses qkv_proj to replace q_proj, k_proj and v_proj, so there is no model.layers.0.self_attn.q_proj.weight in state_dict.keys().
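
For anyone mapping a Hugging Face checkpoint onto this layout, a hypothetical sketch of the fusion (the actual logic lives in vLLM's LLaMA weight loader; the helper below is illustrative):

import torch

def fuse_qkv(state_dict: dict, layer_idx: int) -> torch.Tensor:
    # Concatenate the separate q/k/v projection weights along the output dimension,
    # which is the shape vLLM's fused qkv_proj parameter expects.
    prefix = f"model.layers.{layer_idx}.self_attn."
    q = state_dict[prefix + "q_proj.weight"]
    k = state_dict[prefix + "k_proj.weight"]
    v = state_dict[prefix + "v_proj.weight"]
    return torch.cat([q, k, v], dim=0)  # [(q_out + k_out + v_out), hidden]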

AniZpZ avatar Sep 28 '23 02:09 AniZpZ

@AniZpZ Loading works for now with this PR: https://github.com/AniZpZ/vllm/pull/1. Could you please accept the PR? Then I will focus on optimizing the accuracy/perplexity of models.

Also, what is the command to generate INT8 KV-Cache?

casper-hansen avatar Sep 28 '23 13:09 casper-hansen

@AniZpZ Loading works for now with this PR: AniZpZ#1. Could you please accept the PR? Then I will focus on optimizing the accuracy/perplexity of models.

Also, what is the command to generate INT8 KV-Cache?

Thank you for your support! We do not want to change the original model loading method and are trying to keep the modifications minimal. Furthermore, we are working on tensor parallelism support for the W8A8 linear layers, so I am going to refactor the code later and will not accept the PR right now.

If you want to generate the INT8 KV-Cache parameters, there are two steps: first use vllm/kv_quant/calibrate.py with a calibration dataset to do the calibration, then use vllm/kv_quant/export_kv_params.py to export the quantized parameters.

You can run the following commands for usage details:

python vllm/kv_quant/calibrate.py --help
python vllm/kv_quant/export_kv_params.py --help

AniZpZ avatar Sep 28 '23 14:09 AniZpZ

Please look at the model loading code soon. Until then, I must use my fork to continue my development. The goal is to create W8A8 models with the KV8 cache smoothly without garbage output. Currently, I cannot get any models to generate normal output in vLLM, looking into this.

EDIT: I have optimized the scaling of the llama model using the SmoothQuant method, and the results are unfortunately not good for Llama models. This suggests we need new methods for quantizing into a W8A8 format.

This result is obtained by using TheBloke/Llama-2-7b-chat-fp16 and running my SmoothQuant repository which is optimized for better perplexity.

| Quantization Type | Scale | Perplexity | % Higher than FP16 |
|---|---|---|---|
| W16A16 | N/A | 6.365 | N/A |
| W8A8 | 0.4 | 21.62 | 239.73% |
| W8A8 | 0.45 | 16.97 | 166.55% |
| W8A8 | 0.5 | 15.80 | 148.26% |
| W8A8 | 0.55 | 15.68 | 146.37% |
| W8A8 | 0.6 | 16.93 | 165.95% |
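
For reference, a sketch of the sweep behind this table, assuming the Scale column is the SmoothQuant migration strength alpha; smooth_and_quantize and evaluate_perplexity are assumed helper functions, not APIs of this repo:

import copy

def sweep_alpha(base_model, act_scales, smooth_and_quantize, evaluate_perplexity,
                alphas=(0.4, 0.45, 0.5, 0.55, 0.6)):
    results = {}
    for alpha in alphas:
        model = copy.deepcopy(base_model)                    # start from the fp16 model each time
        model = smooth_and_quantize(model, act_scales, alpha)  # smooth outliers, then convert to W8A8
        results[alpha] = evaluate_perplexity(model)
    return results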

casper-hansen avatar Sep 28 '23 15:09 casper-hansen

Please look at the model loading code soon. Until then, I must use my fork to continue my development. The goal is to create W8A8 models with the KV8 cache smoothly without garbage output. Currently, I cannot get any models to generate normal output in vLLM, looking into this.

EDIT: I have optimized the scaling of the llama model using SmoothQuant method, and the results are unfortunately not good for Llama models. It suggests we need new methods for quantizing into a W8A8 format.

This result is obtained by using TheBloke/Llama-2-7b-chat-fp16 and running my SmoothQuant repository which is optimized for better perplexity.

| Quantization Type | Scale | Perplexity | % Higher than FP16 |
|---|---|---|---|
| W16A16 | N/A | 6.365 | N/A |
| W8A8 | 0.4 | 21.62 | 239.73% |
| W8A8 | 0.45 | 16.97 | 166.55% |
| W8A8 | 0.5 | 15.80 | 148.26% |
| W8A8 | 0.55 | 15.68 | 146.37% |
| W8A8 | 0.6 | 16.93 | 165.95% |

Thank you for your experiment data, we will look into the problem. Our earlier experiments focused on accuracy in datasets like mmlu. We found there were accuracy drops, but they were within an acceptable range.

AniZpZ avatar Oct 09 '23 05:10 AniZpZ

@AniZpZ SmoothQuant only achieves high accuracy with per-token dynamic activation quantization and per-channel weight quantization.

image
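
For clarity, a minimal illustration of the two activation-quantization granularities being compared here (per-tensor static with a calibrated scale versus per-token dynamic):

import torch

def quant_per_tensor_static(x: torch.Tensor, scale: float):
    # One calibrated scale for the whole activation tensor.
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

def quant_per_token_dynamic(x: torch.Tensor):
    # x: [tokens, hidden]; one scale per token, computed on the fly.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale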

MeJerry215 avatar Oct 09 '23 08:10 MeJerry215

@AniZpZ SmoothQuant only achieves high accuracy with per-token dynamic activation quantization and per-channel weight quantization.

image

Yes, per-token dynamic activation quantization with per-channel weight quantization can achieve higher accuracy than per-tensor static activation quantization with tensor-wise weight quantization. However, according to the SmoothQuant paper the difference is minor. We will look into the perplexity problem soon.

image

AniZpZ avatar Oct 09 '23 11:10 AniZpZ

@AniZpZ The root cause is that the activations are still too large; you can dump the down proj's input tensor. If you disable quantization of the LLaMA model's down proj, accuracy improves a lot, and if you further disable the output proj, the final accuracy only drops by about 0.2 PPL. The OPT model's activations are not this abnormal.

Here are the LLaMA activation and output min/max values:

layer: model.layers.2.self_attn.q_proj min: 0.00 max: 7.50
layer: model.layers.2.self_attn.q_proj_out min: 0.69 max: 11.41
layer: model.layers.2.self_attn.k_proj min: 0.00 max: 7.50
layer: model.layers.2.self_attn.k_proj_out min: 0.73 max: 16.66
layer: model.layers.2.self_attn.v_proj min: 0.00 max: 7.50
layer: model.layers.2.self_attn.v_proj_out min: 0.27 max: 4.79
layer: model.layers.2.self_attn.o_proj min: 0.12 max: 4.79
layer: model.layers.2.self_attn.o_proj_out min: 0.20 max: 9.45
layer: model.layers.2.mlp.gate_proj min: 0.26 max: 13.90
layer: model.layers.2.mlp.gate_proj_out min: 0.64 max: 37.94
layer: model.layers.2.mlp.up_proj min: 0.26 max: 13.90
layer: model.layers.2.mlp.up_proj_out min: 0.68 max: 33.28
layer: model.layers.2.mlp.down_proj min: 0.19 max: 1263.00
layer: model.layers.2.mlp.down_proj_out min: 0.28 max: 2426.00

In our work, performance improved by about 30-50% in this case compared to fp16.
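
A minimal sketch of how such per-layer min/max statistics can be collected with forward hooks (illustrative only; not the script used for the dump above):

import torch

def register_minmax_hooks(model, stats: dict):
    def make_hook(name):
        def hook(module, inputs, output):
            x = inputs[0].detach().float().abs()
            y = output.detach().float().abs()
            prev_x = stats.get(name, (float("inf"), 0.0))
            prev_y = stats.get(name + "_out", (float("inf"), 0.0))
            # Track running min/max of |input| and |output| across calibration batches.
            stats[name] = (min(prev_x[0], x.min().item()), max(prev_x[1], x.max().item()))
            stats[name + "_out"] = (min(prev_y[0], y.min().item()), max(prev_y[1], y.max().item()))
        return hook

    handles = []
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call h.remove() on each handle after the calibration pass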

MeJerry215 avatar Oct 10 '23 01:10 MeJerry215

2426.00

@AniZpZ The root cause is that the activations are still too large; you can dump the down proj's input tensor. If you disable quantization of the LLaMA model's down proj, accuracy improves a lot, and if you further disable the output proj, the final accuracy only drops by about 0.2 PPL. The OPT model's activations are not this abnormal.

Here's llama activation and output MinMax.

layer: model.layers.2.self_attn.q_proj min: 0.00 max: 7.50
layer: model.layers.2.self_attn.q_proj_out min: 0.69 max: 11.41
layer: model.layers.2.self_attn.k_proj min: 0.00 max: 7.50
layer: model.layers.2.self_attn.k_proj_out min: 0.73 max: 16.66
layer: model.layers.2.self_attn.v_proj min: 0.00 max: 7.50
layer: model.layers.2.self_attn.v_proj_out min: 0.27 max: 4.79
layer: model.layers.2.self_attn.o_proj min: 0.12 max: 4.79
layer: model.layers.2.self_attn.o_proj_out min: 0.20 max: 9.45
layer: model.layers.2.mlp.gate_proj min: 0.26 max: 13.90
layer: model.layers.2.mlp.gate_proj_out min: 0.64 max: 37.94
layer: model.layers.2.mlp.up_proj min: 0.26 max: 13.90
layer: model.layers.2.mlp.up_proj_out min: 0.68 max: 33.28
layer: model.layers.2.mlp.down_proj min: 0.19 max: 1263.00
layer: model.layers.2.mlp.down_proj_out min: 0.28 max: 2426.00

In our work, performance improved by about 30-50% in this case compared to fp16.

We got it. The root cause is that the smoothing operation does not work on the down proj and output proj for now. The SiLU function in the MLP and the rotary embedding force us to dequantize after the qkv proj and gate/up proj, which leads to large activations. This dequantization is not necessary in the OPT model, where activations stay within a limited range throughout the entire MLP and attention computation.

We are working on this problem now.

AniZpZ avatar Oct 10 '23 03:10 AniZpZ

I have conducted more experiments that achieve the same results as in the paper.

There is only one problem: per-channel weight quantization is not compatible with the CUDA kernels, because you end up with multiple weight scales that cannot be used as the single alpha value.

Specifically, the alpha value needs to be just one number to be compatible with the current kernels that are available. I also suspect that if you use per-channel and make it work, it will decrease performance.

casper-hansen avatar Oct 11 '23 11:10 casper-hansen

I have conducted more experiments that achieve the same results as in the paper.

There is only one problem: per-channel weight quantization is not compatible with the CUDA kernels, because you end up with multiple weight scales that cannot be used as the single alpha value.

Specifically, the alpha value needs to be just one number to be compatible with the current kernels that are available. I also suspect that if you use per-channel and make it work, it will decrease performance.

Yes, but it could use the tensor cores if you rewrite the CUDA kernel, and this roughly halves the time.

Here aq and bq are the quantized int8 values, and s and b are the per-channel scales.

A = [ a11  a12 ]  =  [ aq11*s1  aq12*s1 ]
    [ a21  a22 ]     [ aq21*s2  aq22*s2 ]

B = [ b11  b12 ]  =  [ bq11*b1  bq12*b2 ]
    [ b21  b22 ]     [ bq21*b1  bq22*b2 ]

A @ B = [ aq11*s1*bq11*b1 + aq12*s1*bq21*b1    aq11*s1*bq12*b2 + aq12*s1*bq22*b2 ]
        [ aq21*s2*bq11*b1 + aq22*s2*bq21*b1    aq21*s2*bq12*b2 + aq22*s2*bq22*b2 ]

      = [ aq11*bq11 + aq12*bq21    aq11*bq12 + aq12*bq22 ]  *  [ s1*b1  s1*b2 ]   (elementwise)
        [ aq21*bq11 + aq22*bq21    aq21*bq12 + aq22*bq22 ]     [ s2*b1  s2*b2 ]

We rewrote the CUDA kernel with a per-channel implementation (int8 input, FP16 output) and halved the GEMM kernel time.
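
A quick numerical check of this factorization, using random values in the int8 range with per-row activation scales s and per-column weight scales b:

import numpy as np

rng = np.random.default_rng(0)
Aq = rng.integers(-128, 128, size=(4, 8)).astype(np.int32)   # quantized activations
Bq = rng.integers(-128, 128, size=(8, 6)).astype(np.int32)   # quantized weights
s = rng.random(4).astype(np.float32)                          # per-row (per-token) scales
b = rng.random(6).astype(np.float32)                          # per-column (per-channel) scales

full = (Aq * s[:, None]) @ (Bq * b[None, :])                  # dequantize first, then matmul
factored = (Aq @ Bq) * np.outer(s, b)                         # int32 matmul, then elementwise rescale

assert np.allclose(full, factored, rtol=1e-5)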

MeJerry215 avatar Oct 12 '23 02:10 MeJerry215

We have made an update that fuses the quant/dequant kernels and further improves performance. With W8A8 inference and KV cache quant together we now achieve a throughput improvement of over 50%! We also fixed a bug in KV cache quant and added validation on the MMLU dataset.

Throughput of FP16 LLaMA-13B:

Throughput:  4.9945 requests/s, 543.0660 token/s

Throughput of W8A8 LLaMA-13B:

Throughput:  6.5197  requests/s, 708.9054  token/s

Throughput of W8A8 and KV Cache Quant LLaMA-13B:

Throughput:  7.6821  requests/s, 835.2985  token/s

We also evaluated the quantized models on the MMLU dataset:

| model | STEM | Social Sciences | Humanities | Other | Average |
|---|---|---|---|---|---|
| fp16 | 0.4056 | 0.4965 | 0.4750 | 0.4855 | 0.4657 |
| fp16 with int8kv | 0.4037 | 0.4956 | 0.4748 | 0.4849 | 0.4648 |
| w8a8 with int8kv | 0.2584 | 0.2559 | 0.2509 | 0.2589 | 0.2570 |
| w8a8 partial with int8kv | 0.4009 | 0.4796 | 0.4575 | 0.4766 | 0.4537 |

w8a8 partial with int8kv: we apply quantization but keep the down proj and out proj in fp16, which yields a minimal accuracy drop. With this method we achieve 706.0325 token/s, roughly a 30% throughput improvement!

AniZpZ avatar Oct 19 '23 12:10 AniZpZ