fix: BART attention fusion for key with bias🐛

Open KarelZe opened this issue 4 months ago • 1 comments

Description

#24857 revamped attention fusion for Whisper (and BART). :100: This PR extends that PR and adds attention fusion support for BART encoders whose key projection includes a bias term.
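To make "key with bias" concrete, here is a minimal NumPy sketch of the arithmetic only. It is not the fusion code from this PR. It assumes a packed Q/K/V layout (weights concatenated along the output axis, biases concatenated into one vector) purely for illustration; the sizes and variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden = 4, 8  # illustrative sizes, not taken from the model

x = rng.standard_normal((seq_len, hidden))

# Unfused view: each of the query, key and value paths is a MatMul followed by an Add.
w_q, w_k, w_v = (rng.standard_normal((hidden, hidden)) for _ in range(3))
b_q, b_k, b_v = (rng.standard_normal(hidden) for _ in range(3))

q = x @ w_q + b_q
k = x @ w_k + b_k  # the key path carries a bias, as in the BART encoder
v = x @ w_v + b_v

# Packed view (assumed layout): one weight matrix and one bias vector for all three paths.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)  # shape (hidden, 3 * hidden)
b_qkv = np.concatenate([b_q, b_k, b_v])          # shape (3 * hidden,)
q2, k2, v2 = np.split(x @ w_qkv + b_qkv, 3, axis=1)

packed_matches = all(np.allclose(a, b) for a, b in [(q, q2), (k, k2), (v, v2)])
print("packed projection matches separate projections:", packed_matches)
```

A model without a key bias would correspond to `b_k` being all zeros in this sketch.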

Minimal reproducible example:

(onnxruntime) markusbilz@Markuss-Mini git % uv pip show transformers 
Using Python 3.11.10 environment at onnxruntime/.venv
Name: transformers
Version: 4.52.4

import os

import numpy as np
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import onnxruntime as ort
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_name = "hf-internal-testing/tiny-random-bart"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.eval()


class EncoderWrapper(torch.nn.Module):
    """A wrapper around the BART encoder for onnx export."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor:
        outs = self.encoder(input_ids, attention_mask)
        return outs["last_hidden_state"]


model = EncoderWrapper(encoder=model.model.encoder)
print(model)

text = "God bless the internet."
inputs = tokenizer(text, return_tensors="pt")

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

input_names = ["input_ids"]
output_names = ["encoder_output"]

onnx_path = "bart_model.onnx"

torch.onnx.export(
    model,
    (input_ids,),
    onnx_path,
    export_params=True,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "encoder_output": {0: "batch_size", 1: "sequence_length"},
    },
    opset_version=20,
)
print(f"BART encoder exported to {onnx_path}")

optimization_options = FusionOptions("bart")
optimization_options.enable_attention = True

m = optimizer.optimize_model(
    onnx_path,
    model_type="bart",
    num_heads=0,
    hidden_size=0,
    opt_level=2,
    use_gpu=False,
    verbose=True,
    optimization_options=optimization_options,
    only_onnxruntime=False,
)

optimized_path = "bart_encoder_optimized.onnx"
m.save_model_to_file(optimized_path)

print(f"Optimized ONNX model saved to {optimized_path}")
print(m.get_fused_operator_statistics())

sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
encoder_outs_original = sess.run(["encoder_output"], {"input_ids": input_ids.numpy()})

sess_optimized = ort.InferenceSession(optimized_path, providers=["CPUExecutionProvider"])
encoder_outs_optimized = sess_optimized.run(["encoder_output"], {"input_ids": input_ids.numpy()})

abs_diff = np.amax(np.abs(encoder_outs_original[0] - encoder_outs_optimized[0]))
print("abs_difference", abs_diff)

Output after PR:

Please specify parameters of num_heads and hidden_size for model_type bart
Optimized ONNX model saved to bart_encoder_optimized.onnx
{'EmbedLayerNormalization': 1, 'Attention': 2, 'MultiHeadAttention': 0, 'Gelu': 0, 'FastGelu': 0, 'BiasGelu': 2, 'GemmFastGelu': 0, 'LayerNormalization': 0, 'SimplifiedLayerNormalization': 0, 'SkipLayerNormalization': 4, 'SkipSimplifiedLayerNormalization': 0, 'RotaryEmbedding': 0, 'QOrderedAttention': 0, 'QOrderedGelu': 0, 'QOrderedLayerNormalization': 0, 'QOrderedMatMul': 0}
abs_difference 2.3841858e-07
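The two things worth checking in this output are which fusions actually fired and whether the optimized model still matches the original numerically. A small sketch of how both checks could be automated follows. The helper names and the `atol=1e-4` threshold are assumptions made for illustration, not values from onnxruntime or from this PR; the statistics are copied from the log above.

```python
import math

# Fused operator statistics copied from the "Output after PR" log.
fused_stats = {
    "EmbedLayerNormalization": 1, "Attention": 2, "MultiHeadAttention": 0,
    "Gelu": 0, "FastGelu": 0, "BiasGelu": 2, "GemmFastGelu": 0,
    "LayerNormalization": 0, "SimplifiedLayerNormalization": 0,
    "SkipLayerNormalization": 4, "SkipSimplifiedLayerNormalization": 0,
    "RotaryEmbedding": 0, "QOrderedAttention": 0, "QOrderedGelu": 0,
    "QOrderedLayerNormalization": 0, "QOrderedMatMul": 0,
}


def nonzero_fused_ops(stats: dict[str, int]) -> dict[str, int]:
    """Keep only the fused operators that appear at least once."""
    return {op: count for op, count in stats.items() if count > 0}


def within_tolerance(abs_diff: float, atol: float = 1e-4) -> bool:
    """Return True if the maximum absolute difference is finite and at most atol.

    The default atol is a hypothetical threshold chosen for this sketch.
    """
    return math.isfinite(abs_diff) and abs_diff <= atol


active = nonzero_fused_ops(fused_stats)
print(active)
print("Attention fused:", active.get("Attention", 0) > 0)
print("outputs match:", within_tolerance(2.3841858e-07))
```

In a test, the same checks could be expressed as assertions on `m.get_fused_operator_statistics()` and on `abs_diff` from the script above.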

Motivation and Context

Extends #24857. Closes #23864.

@kunal-vaishnavi @justinchuby Could you please review? I'd also like to add a test case. Could you give me some guidance on where it should go? Should I add the modelling code to onnxruntime/test/python/transformers/test_bart.py? Any feedback is greatly appreciated.

KarelZe, Jun 13 '25 06:06