fix: BART attention fusion for key with bias🐛
Description
With #24857, attention fusion for Whisper (and BART) was revamped. :100: This PR extends that work and adds support for attention fusion in BART encoders whose key projection carries a bias term.
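For background (my reading of the situation, not stated in the PR itself): in Hugging Face transformers, Whisper creates its key projection with bias=False, while BART's key projection has a bias, so the exported BART graph contains an extra Add on the key path that the existing pattern matcher did not accept. A minimal sketch of the difference, with hypothetical variable names:

import torch.nn as nn

embed_dim = 16
# Whisper-style key projection: no bias, exports as a bare MatMul.
k_proj_whisper = nn.Linear(embed_dim, embed_dim, bias=False)
# BART-style key projection: bias term, exports as MatMul followed by Add,
# which the fusion must now also recognize on the key path.
k_proj_bart = nn.Linear(embed_dim, embed_dim, bias=True)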
Minimal reproducible example:
(onnxruntime) markusbilz@Markuss-Mini git % uv pip show transformers
Using Python 3.11.10 environment at onnxruntime/.venv
Name: transformers
Version: 4.52.4
import os
import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import onnxruntime as ort
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
class EncoderWrapper(torch.nn.Module):
    """A wrapper around the BART encoder for ONNX export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]
model = EncoderWrapper(encoder=model.model.encoder)
print(model)
text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
input_names = ["input_ids"]
output_names = ["encoder_output"]
onnx_path = "bart_model.onnx"
torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
)
print(f"BART encoder exported to {onnx_path}")
optimization_options = FusionOptions("bart")
optimization_options.enable_attention = True
m = optimizer.optimize_model(
    onnx_path,
    model_type="bart",
    num_heads=0,  # 0 triggers auto-detection (hence the warning in the output below)
    hidden_size=0,  # 0 triggers auto-detection
    opt_level=2,
    use_gpu=False,
    verbose=True,
    optimization_options=optimization_options,
    only_onnxruntime=False,
)
optimized_path = "bart_encoder_optimized.onnx"
m.save_model_to_file(optimized_path)
print(f"Optimized ONNX model saved to {optimized_path}")
print(m.get_fused_operator_statistics())
sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
encoder_outs_original = sess.run(["encoder_output"], {"input_ids": input_ids.numpy()})
sess_optimized = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
encoder_outs_optimized = sess_optimized.run(["encoder_output"], {"input_ids": input_ids.numpy()})
abs_diff = np.amax(np.abs(encoder_outs_original[0] - encoder_outs_optimized[0]))
print("abs_difference", abs_diff)
Output with this PR applied:
Please specify parameters of num_heads and hidden_size for model_type bart
Optimized ONNX model saved to bart_encoder_optimized.onnx
{'EmbedLayerNormalization': 1, 'Attention': 2, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 2, 'GemmFastGelu': 0, 'LayerNormalization': 0, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 4, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
abs_difference 2.3841858e-07
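As an additional sanity check (not part of the original repro), the optimized graph can also be inspected directly for fused nodes; a minimal sketch using the onnx package:

import onnx

opt_model = onnx.load("bart_encoder_optimized.onnx")
# Expect one fused Attention node per encoder layer (two for this tiny model).
print([n.op_type for n in opt_model.graph.node if n.op_type == "Attention"])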
Motivation and Context
Extends #24857. Closes #23864.
@kunal-vaishnavi @justinchuby Could you please review? I'd also like to add a test case; could you provide some guidance on where it should go? Should the modelling code be added to onnxruntime/test/python/transformers/test_bart.py? Any feedback is greatly appreciated.