
Onnx granite

Open gabe-l-hart opened this issue 4 months ago • 8 comments

What does this PR do?

This PR adds support for models using IBM's GraniteForCausalLM architecture when converting to ONNX. The key changes are:

  • ~~Allow users to opt into using transformers>=4.45 for onnx conversions~~ No longer needed
  • Add "granite" to model configs and tasks
  • Add "granite" as a model_type that uses grouped attention

NOTE: I encountered an issue very similar to the one discussed in https://github.com/huggingface/optimum/issues/1835. The root cause for me was the need to add "granite" to the list of models requiring Grouped Query Attention (GQA) in modeling_decoder.py. I don't believe that is the root cause for #1835, since "llama" is already present there, but #1835 is likely a similar problem: the inference module using num_attention_heads where it should use num_key_value_heads.
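
To make the failure mode in the note concrete, here is a minimal sketch of how the per-layer key/value cache shape depends on which head count is used. It is an illustration only, not the code in modeling_decoder.py; the helper name `kv_cache_shape` and the config values are hypothetical, and the `(batch, kv_heads, seq_len, head_dim)` layout is assumed.

```python
def kv_cache_shape(config: dict, batch_size: int, seq_len: int) -> tuple[int, int, int, int]:
    """Per-layer key (or value) cache shape for a decoder using MHA or GQA."""
    num_heads = config["num_attention_heads"]
    # Without a separate KV head count, fall back to plain multi-head attention
    num_kv_heads = config.get("num_key_value_heads", num_heads)
    head_dim = config["hidden_size"] // num_heads
    return (batch_size, num_kv_heads, seq_len, head_dim)


# Hypothetical config values, chosen only to make the mismatch visible
gqa_config = {"hidden_size": 2048, "num_attention_heads": 32, "num_key_value_heads": 8}

# Shape when the KV head count is respected
correct = kv_cache_shape(gqa_config, batch_size=1, seq_len=16)

# Shape when the KV head count is ignored and num_attention_heads is used instead
no_kv_key = {k: v for k, v in gqa_config.items() if k != "num_key_value_heads"}
wrong = kv_cache_shape(no_kv_key, batch_size=1, seq_len=16)

print(correct)  # (1, 8, 16, 64)
print(wrong)    # (1, 32, 16, 64)
```

A cache allocated with the second shape does not match what a GQA model exported with 8 key/value heads expects, which is the kind of mismatch described above.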

Rationale

This PR specifically addresses the "GraniteForCausalLM" architecture for IBM's forthcoming Granite family of models. The current ibm/PowerLM-3b model uses this architecture and can serve as a placeholder for testing until the new models are released. The one exception is that the PowerLM model has num_attention_heads and num_key_value_heads set to the same value (no Grouped Query Attention), whereas the new models will use GQA (thus the need for the change to ensure GQA is used for "granite" at inference time).
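
Because the placeholder checkpoint does not exercise the GQA path, it can help to confirm which attention layout a local checkpoint declares before testing with it. The following is a rough sketch, assuming the checkpoint directory contains a Hugging Face style config.json with `num_attention_heads` and, optionally, `num_key_value_heads`; the helper name `attention_layout` is hypothetical.

```python
import json
from pathlib import Path


def attention_layout(model_dir: str) -> str:
    """Classify a checkpoint as MHA, GQA, or MQA from its config.json."""
    config = json.loads((Path(model_dir) / "config.json").read_text())
    num_heads = config["num_attention_heads"]
    # A missing key is treated as "same as num_attention_heads" (assumption)
    num_kv_heads = config.get("num_key_value_heads", num_heads)
    if num_kv_heads == num_heads:
        return "MHA"  # every query head has its own key/value head
    if num_kv_heads == 1:
        return "MQA"  # all query heads share a single key/value head
    return "GQA"      # query heads share key/value heads in groups
```

For example, `attention_layout(f"{home}/models/powerlm-3b")` should return "MHA" if the heads match as described above; a checkpoint that returns "GQA" is needed to cover the new code path.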

Testing

When testing locally, I had the following dependency versions:

onnx==1.16.2
onnxruntime==1.19.2
torch==2.4.1
torchvision==0.19.1
transformers==4.45.2

To test the conversion, I did the following:

optimum-cli export onnx \
  --model $HOME/models/powerlm-3b \
  $HOME/models/powerlm-3b-onnx \
  --task text-generation-with-past
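
If the export needs to be repeated for several checkpoints, the same command can be driven from Python. The sketch below only wraps the CLI invocation shown above with `subprocess`; the helper names `build_export_command` and `run_export` are hypothetical, and it assumes `optimum-cli` is on the PATH.

```python
import shlex
import subprocess
from pathlib import Path


def build_export_command(model_dir, output_dir, task: str = "text-generation-with-past") -> list[str]:
    """Build the argv for the optimum-cli ONNX export shown above."""
    return [
        "optimum-cli", "export", "onnx",
        "--model", str(model_dir),
        str(output_dir),
        "--task", task,
    ]


def run_export(model_dir, output_dir, task: str = "text-generation-with-past") -> None:
    """Run the export and raise CalledProcessError if the CLI exits non-zero."""
    subprocess.run(build_export_command(model_dir, output_dir, task), check=True)


# Show the command that would be executed, without running it
models = Path.home() / "models"
command = build_export_command(models / "powerlm-3b", models / "powerlm-3b-onnx")
print(shlex.join(command))
```

Calling `run_export(models / "powerlm-3b", models / "powerlm-3b-onnx")` would then perform the actual conversion.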

To evaluate the output side-by-side with the source model, I used the following script:

side_by_side.py
"""
Simple function to run and time pre and post optimized models
"""

# Standard
from datetime import timedelta
import argparse
import os
import time

# Third Party
from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

def maybe_to_device(inputs: dict[str, torch.Tensor], device: str | None):
    """Send inputs to the device if desired"""
    if device:
        for k, v in inputs.items():
            inputs[k] = v.to(device)

def run_and_time(
    label: str,
    model_path: str,
    model_class: type[ORTModelForCausalLM] | type[AutoModelForCausalLM],
    prompt: str,
    device: str | None,
    **kwargs,
):
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = model_class.from_pretrained(model_path, device_map=device)
    load_end_time = time.time()
    inputs = tokenizer(prompt, return_tensors="pt")
    tok_end_time = time.time()
    maybe_to_device(inputs, device)
    outputs = model.generate(**inputs, **kwargs)
    gen_end_time = time.time()
    res = tokenizer.decode(outputs[0])
    end_time = time.time()
    print(f"------ {label} ------")
    print(res)
    print(f"Total Time: {timedelta(seconds=end_time-start_time)}")
    print(f"Load Time: {timedelta(seconds=load_end_time-start_time)}")
    print(f"Generate Time: {timedelta(seconds=gen_end_time-tok_end_time)}")
    print()

# Defaults
home = os.getenv("HOME")
assert home, "Need $HOME!"
orig_model_path = f"{home}/models/PowerLM-3b"
onnx_model_path = f"{home}/models/PowerLM-3b-onnx-O4"
prompt = "Write a code to find the maximum value in a list of numbers."
device = "cuda" if torch.cuda.is_available() else None

def main():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--prompt", "-p", default=prompt)
    parser.add_argument("--raw-model", "-r", type=str, default=None)
    parser.add_argument("--onnx-model", "-o", type=str, default=None)
    parser.add_argument("--device", "-d", type=str, default=device)
    args = parser.parse_args()
    if args.raw_model:
        run_and_time("Transformers", args.raw_model, AutoModelForCausalLM, args.prompt, args.device, max_new_tokens=100)
    if args.onnx_model:
        run_and_time("ONNX", args.onnx_model, ORTModelForCausalLM, args.prompt, args.device, max_new_tokens=100)


if __name__ == "__main__":
    main()

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] (N/A) Did you make sure to update the documentation with your changes?
    • This is a model addition and there is no model-specific documentation that I can find
  • [x] (N/A) Did you write any new necessary tests?
    • This is a model addition and there are no model-specific tests that I can find

Who can review?

  • ONNX / ONNX Runtime : @fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun

gabe-l-hart • Oct 08 '24 20:10