Support pytorch engine kv int4/int8 quantization
Only the internlm and llama models are updated in this PR. After #2104, all the models should be updated.
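For reference, a minimal usage sketch of enabling kv cache quantization for the pytorch engine (the model path is illustrative; `quant_policy=8` selects kv int8 and `quant_policy=4` selects kv int4, matching the configs in the test report below):

```python
# Minimal sketch: kv cache quantization for the pytorch engine.
# The model id is illustrative; a local path works the same way.
from lmdeploy import pipeline, PytorchEngineConfig

engine_config = PytorchEngineConfig(quant_policy=8)  # 8 -> kv int8, 4 -> kv int4
pipe = pipeline('meta-llama/Meta-Llama-3.1-8B-Instruct', backend_config=engine_config)
print(pipe('Hi, pls introduce shanghai'))
```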
Benchmarked with `benchmark/profile_throughput.py` on Meta-Llama-3.1-8B-Instruct for the pytorch engine with triton 2.3.0:
| quant-policy | fp16 | int8 | int4 |
|---|---|---|---|
| TTFT (s) | 2.354 | 2.035 | 2.569 |
| RPS (req/s) | 16.767 | 19.179 | 16.810 |
Tested gsm8k accuracy:
| dataset | version | metric | mode | llama3-chat-8b | llama3-chat-8b-kv8 | llama3-chat-8b-kv4 |
|---|---|---|---|---|---|---|
| gsm8k | 1d7fe4 | accuracy | gen | 77.41 | 77.56 | 73.24 |
Please resolve the conflicts.
Since kv int4 requires triton>=2.3.0, it would be good to add a check in the engine: https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/pytorch/check_env/init.py
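For illustration, a minimal sketch of such a guard (the function name and error message are hypothetical, not the actual check_env code):

```python
# Hypothetical sketch of a triton version guard for kv int4;
# illustrative only, not the actual check_env implementation.
from packaging import version


def check_triton_for_kv_quant(quant_policy: int) -> None:
    """Raise early if kv int4 is requested but triton is too old."""
    if quant_policy != 4:
        return
    import triton
    if version.parse(triton.__version__) < version.parse('2.3.0'):
        raise RuntimeError(
            f'quant_policy=4 (kv int4) requires triton>=2.3.0, '
            f'but triton {triton.__version__} is installed.')
```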
quant_policy might not be a good argument name, since users might misunderstand it as online quantization (GEMM).
Yes, but it is the same name as in turbomind.
@AllentDan Can you update the supported models? I will add test cases accordingly: https://github.com/InternLM/lmdeploy/blob/main/docs/en/supported_models/supported_models.md
I did not test all the models, since some may fail when quant_policy=4. Among the models I tested, InternLM/internlm2-chat-1_8b, baichuan2/Baichuan2-13B-Chat, and Meta-Llama-3.1-8B-Instruct worked, while Qwen/Qwen2-1.5B-Instruct failed.
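For reference, the kind of quick smoke test used here can be sketched as follows (model ids and prompt are illustrative):

```python
# Rough smoke test: load each model with kv int4 (quant_policy=4) and run one prompt.
# Model ids are illustrative; in practice each model is better run in a fresh
# process so GPU memory is released between runs.
from lmdeploy import pipeline, PytorchEngineConfig

models = [
    'internlm/internlm2-chat-1_8b',
    'baichuan-inc/Baichuan2-13B-Chat',
    'meta-llama/Meta-Llama-3.1-8B-Instruct',
    'Qwen/Qwen2-1.5B-Instruct',
]
for model in models:
    try:
        pipe = pipeline(model, backend_config=PytorchEngineConfig(quant_policy=4))
        print(model, '->', pipe('Hi, pls introduce shanghai'))
    except Exception as exc:  # report the failure and continue with the next model
        print(model, 'failed:', exc)
```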
All models supported by the pytorch backend, including 4-bit weight-quantized variants, were tested. The following errors were found:
- deepseek-ai/DeepSeek-V2-Lite-Chat is not supported with either kvint4 or kvint8. The config and error are:
pipe = pipeline("/nvme/qa_test_models/deepseek-ai/DeepSeek-V2-Lite-Chat", backend_config=engine_config)
res = pipe("Hi, pls introduce shanghai")
```
2024-10-10 17:15:27,701 - lmdeploy - ERROR - request.py:21 - Engine loop failed with error:
Traceback (most recent call last):
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
task.result()
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 947, in async_loop
await self._async_loop()
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 941, in _async_loop
await __step()
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 929, in __step
raise e
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 923, in __step
raise out
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 857, in _async_loop_background
await self._async_step_background(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 739, in _async_step_background
output = await self._async_model_forward(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/utils.py", line 239, in __tmp
return (await func(*args, **kwargs))
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 630, in _async_model_forward
ret = await __forward(inputs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 608, in __forward
return await self.model_agent.async_forward(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 303, in async_forward
output = self._forward_impl(inputs,
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 270, in _forward_impl
output = model_forward(
File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 153, in model_forward
output = model(**input_dict)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 160, in __call__
runner.capture(**kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 77, in capture
self.model(**padded_kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 636, in forward
hidden_states = self.model(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 591, in forward
hidden_states, residual = decoder_layer(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 491, in forward
hidden_states = self.self_attn(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/deepseek_v2.py", line 265, in forward
attn_output = self.attn_fwd(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/nn/attention.py", line 67, in forward
return self.impl.forward(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/attention.py", line 109, in forward
self.paged_attention_fwd(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/kernels/cuda/pagedattention.py", line 1064, in paged_attention_fwd
assert Lq == Lk * 2 and Lv * 2 == o.shape[-1]
AssertionError
```
- Some models produce nonsensical responses on kvint4. The affected models are:
- microsoft/Phi-3-mini-4k-instruct-inner-4bits
- microsoft/Phi-3-mini-4k-instruct
- microsoft/Phi-3-vision-128k-instruct
- OpenGVLab/InternVL2-4B
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
- openbmb/MiniCPM-V-2_6
- [already fixed] Qwen/Qwen2-VL-2B-Instruct and Qwen/Qwen2-VL-7B-Instruct are not supported. The config and error are:
```python
engine_config = PytorchEngineConfig(dtype='auto', tp=1, session_len=None, max_batch_size=None, cache_max_entry_count=0.8, prefill_interval=16, block_size=64, num_cpu_blocks=0, num_gpu_blocks=0, adapters=None, max_prefill_token_num=4096, thread_safe=False, enable_prefix_caching=False, device_type='cuda', eager_mode=False, custom_module_map=None, download_dir=None, revision=None, quant_policy=8)
pipe = pipeline("/nvme/qa_test_models/Qwen/Qwen2-VL-2B-Instruct", backend_config=engine_config)
res = pipe("Hi, pls introduce shanghai")
```
```
2024-10-10 17:17:41,974 - lmdeploy - ERROR - request.py:21 - Engine loop failed with error: 'NoneType' object has no attribute 'stride'
Traceback (most recent call last):
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/request.py", line 17, in _raise_exception_on_finish
task.result()
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 947, in async_loop
await self._async_loop()
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 941, in _async_loop
await __step()
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 929, in __step
raise e
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 923, in __step
raise out
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 857, in _async_loop_background
await self._async_step_background(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 739, in _async_step_background
output = await self._async_model_forward(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/utils.py", line 239, in __tmp
return (await func(*args, **kwargs))
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 630, in _async_model_forward
ret = await __forward(inputs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/engine.py", line 608, in __forward
return await self.model_agent.async_forward(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 303, in async_forward
output = self._forward_impl(inputs,
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 270, in _forward_impl
output = model_forward(
File "/opt/py3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/engine/model_agent.py", line 153, in model_forward
output = model(**input_dict)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 160, in __call__
runner.capture(**kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/graph_runner.py", line 77, in capture
self.model(**padded_kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 379, in forward
hidden_states = self.model(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 318, in forward
hidden_states, residual = decoder_layer(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 226, in forward
hidden_states = self.self_attn(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/models/qwen2_vl.py", line 121, in forward
attn_output = self.attn_fwd(
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/nn/attention.py", line 67, in forward
return self.impl.forward(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/backends/cuda/attention.py", line 86, in forward
self.fill_kv_cache(
File "/opt/py3/lib/python3.10/site-packages/lmdeploy/pytorch/kernels/cuda/fill_kv_cache.py", line 461, in fill_kv_cache
stride_kszn=k_scales_zeros.stride(0),
AttributeError: 'NoneType' object has no attribute 'stride'
```
@AllentDan Qwen2-VL-2B and Qwen2-VL-7B pass on kvint8.