Unable to export ONNX from the 8-bit quantized llama2-7b PyTorch model
This is a torch.onnx.export issue, filed here for tracking. TBD: file it on PyTorch as well.
Steps to reproduce (STR):
Set in your shell:
export HF_HOME=
Save the following as model.py:
import sys, argparse
import torch
import torch.nn as nn
import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPTQConfig,
)
modelname = "TheBloke/Llama-2-7B-GPTQ"
kwargs = {
    "torch_dtype": torch.float32,
    "trust_remote_code": True,
}
quantization_config = GPTQConfig(bits=8, disable_exllama=True)
kwargs["quantization_config"] = quantization_config
kwargs["device_map"] = "cpu"
model = AutoModelForCausalLM.from_pretrained(
    modelname, low_cpu_mem_usage=True, attn_implementation="eager", **kwargs
)
# model.output_hidden_states = False
tokenizer = AutoTokenizer.from_pretrained(modelname)
prompt = "What is nature of our existence?"
encoding = tokenizer(prompt, return_tensors="pt")
test_input = encoding["input_ids"].cpu()
test_output = model.generate(
    test_input,
    do_sample=True,
    top_k=50,
    max_length=100,
    top_p=0.95,
    temperature=1.0,
)[0]
# forward_out = model.forward(test_input)
print("Prompt:", prompt)
print("Response:", tokenizer.decode(test_output))
print("Input:", test_input)
print("Output:", test_output)
onnx_name = "model.onnx"
onnx_program = torch.onnx.export(model, test_input, onnx_name, export_params=True)
Run:
python ./model.py
Error I get:
Traceback (most recent call last):
  File "/proj/gdba/kumar/nod/SHARK-TestSuite/e2eshark/test-run-fp32-onnx/pytorch/models/llama2-7b-GPTQ/runmodel.py", line 124, in <module>
    onnx_program = torch.onnx.export(model, test_input, onnx_name, export_params=True)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch/onnx/utils.py", line 1613, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch/onnx/utils.py", line 1139, in _model_to_graph
    graph = _optimize_graph(
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch/onnx/utils.py", line 677, in _optimize_graph
    graph = _C._jit_pass_onnx(graph, operator_export_type)
  File "/proj/gdba/kumar/anaconda3/envs/e2e/lib/python3.10/site-packages/torch/onnx/utils.py", line 1967, in _run_symbolic_function
    raise errors.UnsupportedOperatorError(
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::bitwise_right_shift' to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
I tried different opsets, but all failed. Looks like we should report to PyTorch?
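For an upstream report, a much smaller reproducer than the full llama2 script may be handy. The sketch below is only an assumption about what such a reproducer could look like: `ShiftOnly` and `try_export` are hypothetical names, and the tiny module merely stands in for the quantized layers by calling `torch.bitwise_right_shift` once. Whether each opset fails or succeeds will depend on the installed PyTorch version, so the script just prints what it observes.

```python
import io

import torch
import torch.nn as nn


class ShiftOnly(nn.Module):
    """Hypothetical stand-in for the quantized layers: a single right shift."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.bitwise_right_shift(x, 4)


def try_export(opset: int) -> str:
    """Attempt an export at one opset; return "ok" or the error text."""
    packed = torch.tensor([0x12345678], dtype=torch.int32)
    try:
        # Export into an in-memory buffer so no file is left behind.
        torch.onnx.export(ShiftOnly(), (packed,), io.BytesIO(), opset_version=opset)
    except Exception as exc:  # broad on purpose: we only want to report the outcome
        return f"{type(exc).__name__}: {exc}"
    return "ok"


# The opsets listed here are arbitrary picks for illustration.
results = {opset: try_export(opset) for opset in (14, 17)}
for opset, outcome in results.items():
    print(f"opset {opset}: {outcome[:120]}")
```

If the error message names 'aten::bitwise_right_shift' here as well, the failure does not depend on the model or on GPTQ loading, which should make the upstream issue easier to triage.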
@daveliddell to file an issue on PyTorch, drive its resolution, and add a link to the PyTorch issue here.
It seems the op is simply not yet supported for export to ONNX, as documented in the PyTorch docs here. I will try to file a request to add support, citing its importance to llama as motivation for them to do it. :-)
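For context on why this op appears at all: GPTQ-style layers generally store several low-bit weights packed into one 32-bit integer and recover them at run time with a right shift and a mask. The pure-Python sketch below illustrates that idea only. The helper names `pack` and `unpack` and the least-significant-first layout are assumptions made for illustration, not taken from the model's actual kernels.

```python
BITS = 8                  # matches GPTQConfig(bits=8) in the script above
PER_WORD = 32 // BITS     # four 8-bit values fit in one 32-bit word
MASK = (1 << BITS) - 1    # 0xFF, keeps only the low BITS bits


def pack(values):
    """Pack PER_WORD small integers into one word (assumed layout: LSB first)."""
    word = 0
    for i, v in enumerate(values):
        word |= (v & MASK) << (i * BITS)
    return word


def unpack(word):
    """Recover the packed values: right shift to the slot, then mask."""
    return [(word >> (i * BITS)) & MASK for i in range(PER_WORD)]


quantized = [17, 250, 3, 128]
word = pack(quantized)
print(hex(word))
print(unpack(word))
```

The `>>` in `unpack` corresponds to the tensor-level right shift that the exporter rejects, so any export path for this model needs some ONNX lowering for it.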
Added PyTorch issue: https://github.com/pytorch/pytorch/issues/119621