
xetla mm_int4

Open · yangw1234 opened this issue 1 year ago · 1 comment

Description

Use an experimental mm_int4 implementation based on xetla to improve batch inference.

1. Why the change?

Enhance performance for batch inference

2. User API changes

Added an enable_xetla argument to from_pretrained; it currently only supports sym_int4.

    import torch
    from ipex_llm.transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit='sym_int4',
                                                 trust_remote_code=True,
                                                 use_cache=True,
                                                 torch_dtype=torch.float16,
                                                 enable_xetla=False,
                                                 )
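For completeness, a minimal usage sketch with the flag actually turned on (the example above leaves it off) might look like the following; model_path is a placeholder, and the runtime effect of moving to the XPU is described in the summary below.

    # Assumed usage sketch: the same call with enable_xetla=True, then moved
    # to the XPU so the xetla weight layout is applied (see summary below).
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_low_bit='sym_int4',
                                                 torch_dtype=torch.float16,
                                                 enable_xetla=True)
    model = model.to('xpu')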

3. Summary of the change

  1. Added enable_xetla to from_pretrained.
  2. If enable_xetla is set to True, the function q4_0_xpu_transpose will be applied to all LowBitLinear weights when model.to("xpu") is called.
  3. If enable_xetla is set to True, LowBitLinear will use linear_q4_0.mm_int4 to compute the result; a simplified sketch of this dispatch follows the list.
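To make item 3 concrete, here is a minimal, self-contained sketch of the dispatch; it is not the actual LowBitLinear code. The stub functions xetla_mm_int4 and default_mm_int4 are placeholders standing in for linear_q4_0.mm_int4 and the existing low-bit matmul, whose real signatures are not shown in this PR.

    import torch
    import torch.nn as nn

    def xetla_mm_int4(x, qweight):
        # Placeholder for linear_q4_0.mm_int4; the real kernel expects the
        # weight to have been transposed by q4_0_xpu_transpose on .to("xpu").
        raise NotImplementedError("illustrative stub only")

    def default_mm_int4(x, qweight):
        # Placeholder for the existing (non-xetla) low-bit matmul path.
        raise NotImplementedError("illustrative stub only")

    class LowBitLinearSketch(nn.Module):
        """Simplified stand-in for LowBitLinear showing the enable_xetla branch."""
        def __init__(self, qweight, enable_xetla=False):
            super().__init__()
            self.qweight = qweight              # packed sym_int4 weight
            self.enable_xetla = enable_xetla

        def forward(self, x):
            if self.enable_xetla:
                return xetla_mm_int4(x, self.qweight)
            return default_mm_int4(x, self.qweight)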

4. How to test?

Manually tested on llama-7b and llama-13b.

5. New dependencies

xetla is introduced as a new dependency in llm.cpp.

yangw1234 · Jan 30 '24

Please resolve the conflicts.

cyita · Feb 26 '24

I added enable_xetla=True to all-in-one's run.py to test this.

  1. After generation, we call to('cpu') to release the memory on the XPU, and we get the exception below:
2024-03-28 00:32:57,238 - ERROR - 

****************************Usage Error************************
xetla is not supported on CPUs but got enable_xetla=True
2024-03-28 00:32:57,238 - ERROR - 

****************************Call Stack*************************
Traceback (most recent call last):
  File "/home/wangruonan/xin/BigDL/python/llm/dev/benchmark/all-in-one/run.py", line 1710, in <module>
    run_model(model, api, in_out_pairs, conf['local_model_hub'], conf['warm_up'], conf['num_trials'], conf['num_beams'],
  File "/home/wangruonan/xin/BigDL/python/llm/dev/benchmark/all-in-one/run.py", line 90, in run_model
    result = run_transformer_int4_fp16_gpu_win(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, cpu_embedding, batch_size, streaming)
  File "/home/wangruonan/xin/BigDL/python/llm/dev/benchmark/all-in-one/run.py", line 1024, in run_transformer_int4_fp16_gpu_win
    model.to('cpu')
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1900, in to
    return super().to(*args, **kwargs)
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  [Previous line repeated 2 more times]
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 833, in _apply
    param_applied = fn(param)
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1158, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/ipex_llm/transformers/low_bit_linear.py", line 398, in to
    invalidInputError(False,
  File "/home/wangruonan/anaconda3/envs/xin-llm/lib/python3.9/site-packages/ipex_llm/utils/common/log4Error.py", line 32, in invalidInputError
    raise RuntimeError(errMsg)
RuntimeError: xetla is not supported on CPUs but got enable_xetla=True
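If the goal in run.py is only to free XPU memory after generation, one possible workaround (a sketch, not something this PR changes) is to drop the model reference and empty the XPU cache instead of calling model.to('cpu'), which the enable_xetla check rejects; model below refers to the already-loaded model in the benchmark script.

    import gc
    import torch
    import intel_extension_for_pytorch  # noqa: F401  (registers the torch.xpu backend)

    # Instead of model.to('cpu'), which raises when enable_xetla=True:
    del model
    gc.collect()
    torch.xpu.synchronize()
    torch.xpu.empty_cache()  # release cached device memory on the XPU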

  2. If the model's qtype is fp4, the outputs are all <unk>. We should throw an error instead (a sketch of such a check is below).
  3. If I put both the llama-13b and llama-7b models in config.yaml, the second model's (llama-7b's) outputs are all <unk>. But if I put llama-7b first, the output is fine.
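For the fp4 case, here is a sketch of the kind of guard that could fail fast at load time, assuming it lives wherever enable_xetla is validated (the function name and location are hypothetical); it reuses the invalidInputError helper visible in the traceback above.

    # Import path taken from the traceback above.
    from ipex_llm.utils.common.log4Error import invalidInputError

    def check_xetla_qtype(enable_xetla, load_in_low_bit):
        # Hypothetical guard: enable_xetla only supports sym_int4, so reject
        # other qtypes (e.g. fp4) instead of silently generating <unk> tokens.
        if enable_xetla:
            invalidInputError(load_in_low_bit == 'sym_int4',
                              "enable_xetla=True only supports sym_int4, "
                              f"but got load_in_low_bit={load_in_low_bit!r}")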

qiuxin2012 · Mar 28 '24