xetla mm_int4
Description
Use an experimental mm_int4 implementation based on xetla to enhance batch inference.
1. Why the change?
Enhance performance for batch inference
2. User API changes
Added `enable_xetla` to `from_pretrained`; currently only `sym_int4` is supported.
```python
import torch
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit='sym_int4',
                                             trust_remote_code=True,
                                             use_cache=True,
                                             torch_dtype=torch.float16,
                                             enable_xetla=False)
```
3. Summary of the change
- Added `enable_xetla` to `from_pretrained`.
- If `enable_xetla` is set to `True`, when `model.to("xpu")` is called, the function `q4_0_xpu_transpose` will be applied to all `LowBitLinear` weights.
- If `enable_xetla` is set to `True`, `LowBitLinear` will use `linear_q4_0.mm_int4` to compute the result (both paths are sketched below).
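A minimal sketch of how those two pieces fit together follows. This is not the actual `LowBitLinear` implementation: the weight handling is simplified, the `mm_int4` argument order is assumed, and the fallback path is omitted; only `enable_xetla`, `q4_0_xpu_transpose` and `linear_q4_0.mm_int4` come from this PR.

```python
import torch
import torch.nn as nn
import linear_q4_0  # native q4_0 kernels built from llm.cpp, including the new xetla ops


class LowBitLinearSketch(nn.Module):
    def __init__(self, qweight: torch.Tensor, enable_xetla: bool = False):
        super().__init__()
        self.register_buffer("qweight", qweight)  # packed sym_int4 weight
        self.enable_xetla = enable_xetla

    def _apply(self, fn):
        # model.to("xpu") calls _apply on every submodule; once the weight has
        # landed on the XPU, re-pack it into the layout the xetla kernel expects.
        module = super()._apply(fn)
        if self.enable_xetla and self.qweight.device.type == "xpu":
            self.qweight = linear_q4_0.q4_0_xpu_transpose(self.qweight)
        return module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.enable_xetla:
            # experimental xetla-based int4 matmul
            return linear_q4_0.mm_int4(x, self.qweight)
        # the existing non-xetla q4_0 path is unchanged by this PR and omitted here
        raise NotImplementedError("non-xetla path omitted in this sketch")
```

The key point is that the transpose happens once, when the model is moved to the XPU, so the per-token forward pass only pays for the `mm_int4` call.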
4. How to test?
manually tested on llama-7b and llama-13b
5. New dependencies
xetla is introduced in llm.cpp
Please resolve the conflicts.
I added enable_xetla=True to all-in-one's run.py to test this.
- After the generation, we call `to('cpu')` to release the memory on the XPU, and we get the exception below (a minimal repro is sketched after this list):
2024-03-28 00:32:57,238 - ERROR -
****************************Usage Error************************
xetla is not supported on CPUs but got enable_xetla=True
2024-03-28 00:32:57,238 - ERROR -
****************************Call Stack*************************
Traceback (most recent call last):
File "/home/wangruonan/xin/BigDL/python/llm/dev/benchmark/all-in-one/run.py", line 1710, in <module>
run_model(model, api, in_out_pairs, conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'],
File "/home/wangruonan/xin/BigDL/python/llm/dev/benchmark/all-in-one/run.py", line 90, in run_model
result = run_transformer_int4_fp16_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
File "/home/wangruonan/xin/BigDL/python/llm/dev/benchmark/all-in-one/run.py", line 1024, in run_transformer_int4_fp16_gpu_win
model.to('cpu')
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1900, in to
return super().to(*args, **kwargs)
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1160, in to
return self._apply(convert)
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 810, in _apply
module._apply(fn)
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 810, in _apply
module._apply(fn)
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 810, in _apply
module._apply(fn)
[Previous line repeated 2 more times]
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 833, in _apply
param_applied = fn(param)
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1158, in convert
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/ipex_llm/transformers/low_bit_linear.py", line 398, in to
invalidInputError(False,
File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/ipex_llm/utils/common/log4Error.py", line 32, in invalidInputError
raise RuntimeError(errMsg)
RuntimeError: xetla is not supported on CPUs but got enable_xetla=True
- If the model's qtype is fp4, the output is all `<unk>`. We should throw an error in that case (a possible check is sketched below).
- If I set both the llama-13b and llama-7b models in config.yaml, the second model (llama-7b) outputs all `<unk>`. But if I put llama-7b first, the output is fine.
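For reference, the `to('cpu')` failure above boils down to roughly this sequence (the model path is a placeholder):

```python
import torch
from ipex_llm.transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",  # placeholder path
                                             load_in_low_bit='sym_int4',
                                             torch_dtype=torch.float16,
                                             trust_remote_code=True,
                                             enable_xetla=True)
model = model.to("xpu")
# ... generation works fine on the XPU ...
model = model.to("cpu")  # RuntimeError: xetla is not supported on CPUs but got enable_xetla=True
```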
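And one possible shape for the fp4 check, sketched with the `invalidInputError` helper that appears in the traceback (the import path, the function name, and where the check should live are my guesses, not part of this PR):

```python
from ipex_llm.utils.common import invalidInputError  # import path assumed


def validate_xetla_qtype(load_in_low_bit: str, enable_xetla: bool) -> None:
    # xetla mm_int4 currently only supports sym_int4, so fail fast instead of
    # silently generating <unk> tokens for other qtypes such as fp4.
    invalidInputError(not enable_xetla or load_in_low_bit == "sym_int4",
                      f"enable_xetla=True currently only supports sym_int4, "
                      f"but got load_in_low_bit={load_in_low_bit}")
```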