ONNX Granite
What does this PR do?
This PR adds support for models using IBM's `GraniteForCausalLM` architecture when converting to ONNX. The key changes are:
- ~~Allow users to opt into using `transformers>=4.45` for `onnx` conversions~~ No longer needed
- Add `"granite"` to model configs and tasks (see the sketch below)
- Add `"granite"` as a `model_type` that uses grouped attention
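For reference, the registration typically amounts to a small ONNX config class plus a matching `TasksManager` entry; here is a minimal sketch, assuming `granite` can reuse the existing llama export config (class placement and attribute values are assumptions, not the exact diff from this PR):

```python
# Hypothetical sketch of the registration in
# optimum/exporters/onnx/model_configs.py; not the exact diff from this PR.
from packaging import version

from optimum.exporters.onnx.model_configs import LlamaOnnxConfig


class GraniteOnnxConfig(LlamaOnnxConfig):
    # GraniteForCausalLM shares llama's decoder layout, so the llama ONNX
    # config can be reused; gate on the transformers release that added it.
    MIN_TRANSFORMERS_VERSION = version.parse("4.45.0")
```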
NOTE: I encountered an issue very similar to the one discussed in https://github.com/huggingface/optimum/issues/1835. The root cause for me was the need to add `"granite"` to the list of models requiring Grouped Query Attention in `modeling_decoder.py`. I don't believe that is the root cause for #1835, since `"llama"` is already present there, but it is likely a similar issue showing up in the inference module using `num_attention_heads` instead of `num_key_value_heads`.
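For context, the relevant logic picks the head count used to shape the past key/value cache; a simplified, illustrative sketch of the kind of branch involved (the function, set contents, and names are assumptions, not the real code in `modeling_decoder.py`):

```python
def gqa_num_kv_heads(config) -> int:
    """Pick the head count used for past key/value cache shapes.

    Illustrative sketch only; the real logic lives in
    optimum/onnxruntime/modeling_decoder.py and differs in detail.
    """
    if config.model_type in {"llama", "mistral", "granite"}:
        # GQA-aware: the KV cache uses the (smaller) key/value head count.
        return config.num_key_value_heads
    # Fallback assumes plain multi-head attention.
    return config.num_attention_heads
```

Past cache buffers are shaped roughly as `(batch_size, num_key_value_heads, past_sequence_length, head_dim)`, so using `num_attention_heads` for a GQA model produces a shape mismatch at inference time.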
Rationale
This PR specifically addresses the `GraniteForCausalLM` architecture for IBM's forthcoming Granite family of models. The current ibm/PowerLM-3b model uses this architecture and can serve as a placeholder for testing until the new models are released. The one exception is that the PowerLM model has `num_attention_heads` and `num_key_value_heads` set to match (no Grouped Query Attention), whereas the new models will use GQA (thus the need for the change to ensure GQA is handled for `"granite"` at inference time).
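A quick way to check this on a given checkpoint (a sketch; the expectation about the upcoming models comes from the description above):

```python
# Inspect head counts to see whether a checkpoint actually uses GQA.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("ibm/PowerLM-3b")
# For PowerLM-3b the two match (no GQA); the upcoming Granite models are
# expected to have num_key_value_heads < num_attention_heads.
print(cfg.num_attention_heads, cfg.num_key_value_heads)
```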
Testing
When testing locally, I had the following dependency versions:
```
onnx==1.16.2
onnxruntime==1.19.2
torch==2.4.1
torchvision==0.19.1
transformers==4.45.2
```
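A quick sanity check to confirm the same versions are active in your environment:

```python
# Print the installed versions of the packages listed above.
import onnx
import onnxruntime
import torch
import torchvision
import transformers

for mod in (onnx, onnxruntime, torch, torchvision, transformers):
    print(f"{mod.__name__}=={mod.__version__}")
```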
To test the conversion, I did the following:
```bash
optimum-cli export onnx \
  --model $HOME/models/powerlm-3b \
  $HOME/models/powerlm-3b-onnx \
  --task text-generation-with-past
```
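The same export can also be done programmatically; a minimal sketch assuming default export settings (paths are the same placeholders as in the command above):

```python
# Programmatic equivalent of the optimum-cli command above (a sketch).
import os

from optimum.exporters.onnx import main_export

home = os.environ["HOME"]
main_export(
    f"{home}/models/powerlm-3b",
    output=f"{home}/models/powerlm-3b-onnx",
    task="text-generation-with-past",
)
```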
To evaluate the output side-by-side with the source model, I used the following script:
side_by_side.py

```python
"""
Simple script to run and time the original and ONNX-exported models
"""

# Standard
from datetime import timedelta
import argparse
import os
import time

# Third Party
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


def maybe_to_device(inputs: dict[str, torch.Tensor], device: str | None):
    """Send inputs to the device if desired"""
    if device:
        for k, v in inputs.items():
            inputs[k] = v.to(device)


def run_and_time(
    label: str,
    model_path: str,
    model_class: type[ORTModelForCausalLM] | type[AutoModelForCausalLM],
    prompt: str,
    device: str | None,
    **kwargs,
):
    """Load the model, generate from the prompt, and report timings"""
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = model_class.from_pretrained(model_path, device_map=device)
    load_end_time = time.time()
    inputs = tokenizer(prompt, return_tensors="pt")
    tok_end_time = time.time()
    maybe_to_device(inputs, device)
    outputs = model.generate(**inputs, **kwargs)
    gen_end_time = time.time()
    res = tokenizer.decode(outputs[0])
    end_time = time.time()
    print(f"------ {label} ------")
    print(res)
    print(f"Total Time: {timedelta(seconds=end_time - start_time)}")
    print(f"Load Time: {timedelta(seconds=load_end_time - start_time)}")
    print(f"Generate Time: {timedelta(seconds=gen_end_time - tok_end_time)}")
    print()


# Defaults
home = os.getenv("HOME")
assert home, "Need $HOME!"
orig_model_path = f"{home}/models/PowerLM-3b"
onnx_model_path = f"{home}/models/PowerLM-3b-onnx-O4"
prompt = "Write a code to find the maximum value in a list of numbers."
device = "cuda" if torch.cuda.is_available() else None


def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--prompt", "-p", default=prompt)
    parser.add_argument("--raw-model", "-r", type=str, default=None)
    parser.add_argument("--onnx-model", "-o", type=str, default=None)
    parser.add_argument("--device", "-d", type=str, default=device)
    args = parser.parse_args()
    if args.raw_model:
        run_and_time(
            "Transformers",
            args.raw_model,
            AutoModelForCausalLM,
            args.prompt,
            args.device,
            max_new_tokens=100,
        )
    if args.onnx_model:
        run_and_time(
            "ONNX",
            args.onnx_model,
            ORTModelForCausalLM,
            args.prompt,
            args.device,
            max_new_tokens=100,
        )


if __name__ == "__main__":
    main()
```
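For example, `python side_by_side.py -r $HOME/models/PowerLM-3b -o $HOME/models/PowerLM-3b-onnx-O4` runs both models on the same prompt and prints the generations along with load and generate timings.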
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] (N/A) Did you make sure to update the documentation with your changes?
- This is a model addition and there is no model-specific documentation that I can find
- [x] (N/A) Did you write any new necessary tests?
- This is a model addition and there are no model-specific tests that I can find
Who can review?
- ONNX / ONNX Runtime : @fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun