
[ONNX] Support exporting Hugging Face BART to ONNX

Open titaiwangms opened this issue 2 years ago • 4 comments

Add BART to the transformers model support, specifically for BartForConditionalGeneration.

Motivation and Context

  • fixes #11210

Currently, the custom beam search op is not working in the nightly build, so this PR should be run with a custom commit.

titaiwangms avatar Aug 29 '22 21:08 titaiwangms

This pull request introduces 7 alerts when merging 3ac9cdff2635f75a5e1f920fdbfb42eef4262265 into 17ccd6fa02877a1c8d3201344137b1ca105b681d - view on LGTM.com

new alerts:

  • 5 for Unused import
  • 2 for First parameter of a method is not named 'self'

lgtm-com[bot] avatar Aug 29 '22 22:08 lgtm-com[bot]

I didn't understand the goal of this PR. Why do we need this set of scripts? Is this for creating custom ops, or just an example for ORT? Isn't the PyTorch ONNX converter capable of converting this model without that?

thiagocrepaldi avatar Sep 06 '22 15:09 thiagocrepaldi

I didn't understand the goal of this PR. Why do we need this set of scripts? Is this for creating custom ops, or just an example for ORT? Isn't the PyTorch ONNX converter capable of converting this model without that?

Yes, a custom op is consumed. This is for the paragraph summarization model BartForConditionalGeneration. In order to convert this seq2seq model, we also need the custom beam search op. I believe some teams (customers) are already using this.

There is demand for converting transformer models, and it usually gets difficult when custom ops are needed. This PR provides the BART model conversion, which we already have, and now makes it available in this folder as a demo.

titaiwangms avatar Sep 06 '22 16:09 titaiwangms

This pull request introduces 1 alert when merging 7ddcc212080ddb83cd8025f48bcbea7f39e5f121 into 9edc9465f0258fa4cf83d7975905afa9ce1ec436 - view on LGTM.com

new alerts:

  • 1 for Unused import

lgtm-com[bot] avatar Sep 13 '22 21:09 lgtm-com[bot]

@AllenTiTaiWang, thanks for the effort.

Could you add a license header to the .py files?

Were you able to test the accuracy of the exported model (e.g., compare the generation results with the PyTorch model)?

You mentioned a custom commit; any idea why the code works for T5 but not BART?

Although it doesn't have to be done in this PR, I think the T5 and BART exporters could share a lot of ONNX export code.

tianleiwu avatar Sep 23 '22 20:09 tianleiwu

@AllenTiTaiWang, thanks for the effort.

Could you add a license header to the .py files?

Done

Were you able to test the accuracy of the exported model (e.g., compare the generation results with the PyTorch model)?

I don't have a strict metric to test whether the ONNX and PyTorch outputs match. However, onnx_inference.py prints their predictions on a given example paragraph, and they match on the summarized text.

You mentioned a custom commit; any idea why the code works for T5 but not BART?

The issue was in the custom beam search op. According to @wangyems, the op was changed by another PR, so although 1.12.1 works with BART, the nightly build didn't. However, I am not sure whether this has been fixed on the current main branch, as @wangyems said it would go through internal discussion in the runtime team. cc @wangyems.

Although it doesn't have to be done in this PR, I think the T5 and BART exporters could share a lot of ONNX export code.

SGTM. I will see if there are any improvements I can contribute when I have time.

titaiwangms avatar Sep 26 '22 23:09 titaiwangms
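
A minimal sketch of such a side-by-side comparison, for illustration only: it assumes the model_final.onnx produced by this export flow and the BeamSearch input names that appear later in this thread; the path and hyperparameters are placeholders, not part of the PR.

import numpy as np
import onnxruntime as ort
from transformers import BartForConditionalGeneration, BartTokenizer

model_name = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name).eval()

text = "PG&E stated it scheduled the blackouts in response to forecasts for high winds."
enc = tokenizer([text], return_tensors="pt")

# PyTorch reference, using the same beam search hyperparameters as the ONNX run.
torch_ids = model.generate(enc["input_ids"], num_beams=5, min_length=0, max_length=20)
torch_text = tokenizer.batch_decode(torch_ids, skip_special_tokens=True)

# ONNX model with the fused BeamSearch op (path is illustrative).
sess = ort.InferenceSession("onnx_models/model_final.onnx", providers=["CPUExecutionProvider"])
feeds = {
    "input_ids": enc["input_ids"].numpy().astype(np.int32),
    "attention_mask": enc["attention_mask"].numpy().astype(np.int32),
    "max_length": np.array([20], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([5], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
}
sequences = sess.run(["sequences"], feeds)[0]  # (batch, num_return_sequences, max_length)
ort_text = tokenizer.batch_decode(sequences[0], skip_special_tokens=True)

print("torch:", torch_text)
print("ort:  ", ort_text)
print("match:", torch_text == ort_text)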

The previous mismatch was caused by a zcode change where input_ids is a full copy of sequences; the Hugging Face BART model, however, does not work that way. The fix for this isn't in master yet. I'll draft a PR soon...

wangyems avatar Sep 27 '22 00:09 wangyems

Thanks!

titaiwangms avatar Sep 27 '22 02:09 titaiwangms

I saw different results from PyTorch and ORT. For example, here is the stdout from python export.py -m facebook/bart-base. I used PyTorch 1.12.1+cu116, transformers 4.18.0, and onnxruntime-gpu 1.12.1:

pytorch inference ...
--- 0.6589293479919434 seconds ---
batch 0 : sequence: 0 PG&E stated it scheduled the blackouts in response to forecasts for high winds amid
batch 0 : sequence: 1 PG&E stated it scheduled the blackouts in response to forecasts for high winds and
batch 0 : sequence: 2 PG&E stated it scheduled the blackouts in response to forecasts for high winds during
batch 0 : sequence: 3 PG&E stated it scheduled the blackouts in response to forecasts for high winds,
batch 0 : sequence: 4 PG&E stated it scheduled the blackouts in response to forecasts for high winds in
batch 1 : sequence: 0 PG&E stated it scheduled the blackouts in response to forecasts for high winds amid
batch 1 : sequence: 1 PG&E stated it scheduled the blackouts in response to forecasts for high winds and
batch 1 : sequence: 2 PG&E stated it scheduled the blackouts in response to forecasts for high winds during
batch 1 : sequence: 3 PG&E stated it scheduled the blackouts in response to forecasts for high winds,
batch 1 : sequence: 4 PG&E stated it scheduled the blackouts in response to forecasts for high winds in
batch 2 : sequence: 0 PG&E stated it scheduled the blackouts in response to forecasts for high winds amid
batch 2 : sequence: 1 PG&E stated it scheduled the blackouts in response to forecasts for high winds and
batch 2 : sequence: 2 PG&E stated it scheduled the blackouts in response to forecasts for high winds during
batch 2 : sequence: 3 PG&E stated it scheduled the blackouts in response to forecasts for high winds,
batch 2 : sequence: 4 PG&E stated it scheduled the blackouts in response to forecasts for high winds in
ORT inference ...
--- 2.7532260417938232 seconds ---
batch 0 : sequence: 0 PG&E stated it scheduled the blackouts. The aim is to reduce the risk
batch 0 : sequence: 1 PG&E stated it scheduled the blackouts. The aim is to reduce the
batch 0 : sequence: 2 PG&E stated it scheduled the blackouts. The aim is to reduce the risk.
batch 0 : sequence: 3 PG&E stated it scheduled the blackouts. The aim is to reduce the risk of
batch 0 : sequence: 4 PG&E stated it scheduled the blackouts. The aim is to reduce the customers were
batch 1 : sequence: 0 PG&E stated it scheduled the blackouts. The aim is to reduce the risk
batch 1 : sequence: 1 PG&E stated it scheduled the blackouts. The aim is to reduce the
batch 1 : sequence: 2 PG&E stated it scheduled the blackouts. The aim is to reduce the risk.
batch 1 : sequence: 3 PG&E stated it scheduled the blackouts. The aim is to reduce the risk of
batch 1 : sequence: 4 PG&E stated it scheduled the blackouts. The aim is to reduce the customers were
batch 2 : sequence: 0 PG&E stated it scheduled the blackouts. The aim is to reduce the risk
batch 2 : sequence: 1 PG&E stated it scheduled the blackouts. The aim is to reduce the
batch 2 : sequence: 2 PG&E stated it scheduled the blackouts. The aim is to reduce the risk.
batch 2 : sequence: 3 PG&E stated it scheduled the blackouts. The aim is to reduce the risk of
batch 2 : sequence: 4 PG&E stated it scheduled the blackouts. The aim is to reduce the customers were

tianleiwu avatar Sep 27 '22 20:09 tianleiwu

Sorry, I wasn't being clear. It works with the custom commit, and 1.12.1 simply doesn't hit the segfault error the way the nightly build does.

titaiwangms avatar Sep 28 '22 22:09 titaiwangms

I tested with

python export.py -m facebook/bart-large-cnn

and got:

pytorch inference ...
params set:  {'input_ids', 'return_dict', 'attention_mask', 'use_cache', 'head_mask', 'cross_attn_head_mask', 'past_key_values', 'output_attentions', 'decoder_head_mask', 'decoder_attention_mask', 'decoder_input_ids', 'labels', 'decoder_inputs_embeds', 'encoder_outputs', 'output_hidden_states', 'inputs_embeds'}
accepts_attention_mask:  True
requires_attention_mask:  True
attention_mask:  None
model_kwargs[attention_mask] created!!!
--- 1.7520148754119873 seconds ---
batch 0 : sequence: 0 PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions
batch 0 : sequence: 1 PG&E scheduled blackouts in response to forecasts for high winds amid dry conditions.
batch 0 : sequence: 2 PG&E scheduled the blackouts in response to forecasts for high winds. The aim
batch 0 : sequence: 3 PG&E scheduled the blackouts in response to forecasts for high winds. Nearly 800
batch 0 : sequence: 4 Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The aim is
batch 1 : sequence: 0 PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions
batch 1 : sequence: 1 PG&E scheduled blackouts in response to forecasts for high winds amid dry conditions.
batch 1 : sequence: 2 PG&E scheduled the blackouts in response to forecasts for high winds. The aim
batch 1 : sequence: 3 PG&E scheduled the blackouts in response to forecasts for high winds. Nearly 800
batch 1 : sequence: 4 Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The aim is
batch 2 : sequence: 0 PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions
batch 2 : sequence: 1 PG&E scheduled blackouts in response to forecasts for high winds amid dry conditions.
batch 2 : sequence: 2 PG&E scheduled the blackouts in response to forecasts for high winds. The aim
batch 2 : sequence: 3 PG&E scheduled the blackouts in response to forecasts for high winds. Nearly 800
batch 2 : sequence: 4 Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The aim is
ORT inference ...
--- 1.1799135208129883 seconds ---
batch 0 : sequence: 0 PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.
batch 0 : sequence: 1 PG&E scheduled the blackouts in response to forecasts for high winds. The aim is
batch 0 : sequence: 2 Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The aim is to
batch 0 : sequence: 3 PG&E scheduled the blackouts in response to forecasts for high winds. Nearly 800 thousand
batch 0 : sequence: 4 PG&E scheduled blackouts in response to forecasts for high winds amid dry conditions. Nearly
batch 1 : sequence: 0 PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.
batch 1 : sequence: 1 PG&E scheduled the blackouts in response to forecasts for high winds. The aim is
batch 1 : sequence: 2 Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The aim is to
batch 1 : sequence: 3 PG&E scheduled the blackouts in response to forecasts for high winds. Nearly 800 thousand
batch 1 : sequence: 4 PG&E scheduled blackouts in response to forecasts for high winds amid dry conditions. Nearly
batch 2 : sequence: 0 PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions.
batch 2 : sequence: 1 PG&E scheduled the blackouts in response to forecasts for high winds. The aim is
batch 2 : sequence: 2 Nearly 800 thousand customers were scheduled to be affected by the shutoffs. The aim is to
batch 2 : sequence: 3 PG&E scheduled the blackouts in response to forecasts for high winds. Nearly 800 thousand
batch 2 : sequence: 4 PG&E scheduled blackouts in response to forecasts for high winds amid dry conditions. Nearly

I used bart-large-cnn because it is the example model on the Hugging Face BART model card.

Quoted from the model card:

from transformers import BartTokenizer, BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")

ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="pt")

# Generate Summary
summary_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

# results
'PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions'

The results are not identical in terms of ordering and sentence length (one word differs), which should not be a big issue.

titaiwangms avatar Oct 05 '22 15:10 titaiwangms

Some functions are similar to those in the T5 export script. I suggest refactoring later to consolidate the BART and T5 scripts.

Tracked in #13221.

titaiwangms avatar Oct 05 '22 22:10 titaiwangms

Hello everyone, I have used the branch above to convert the facebook/bart-base PyTorch model to an ONNX model, with the same optional parameters given in the README file. The summarization results from the PyTorch and ONNX models are different. Is this expected? The results are given below. Please help me out.

Packages used:

  • python=3.7
  • pytorch=1.10.0
  • onnx=1.10.1
  • onnxruntime=1.12.1
  • transformers=4.23.1

input_text = "PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."

pytorch_output = 'PG&E stated it scheduled the blackouts in response to forecasts for high winds amid'
onnx_output = 'PG&E stated it scheduled the blackouts. The aim is to reduce the risk'

ellurunaresh avatar Oct 17 '22 12:10 ellurunaresh

Hi @ellurunaresh, could you try with the nightly ORT? If anything is wrong, please raise an issue. Thanks!

titaiwangms avatar Oct 17 '22 19:10 titaiwangms

An error occurred while loading the model using ort-nightly=1.11.0. The error is given below. Please take a look.

onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from onnx_models/model_final.onnx failed:This is an invalid model. In Node, ("BeamSearch_zcode", BeamSearch, "com.microsoft", -1) : ("input_ids": tensor(int32),"max_length": tensor(int32),"min_length": tensor(int32),"num_beams": tensor(int32),"num_return_sequences": tensor(int32),"length_penalty": tensor(float),"repetition_penalty": tensor(float),"","","attention_mask": tensor(int32),) -> ("sequences": tensor(int32),) , Error Unrecognized attribute: encoder for operator BeamSearch

ellurunaresh avatar Oct 18 '22 04:10 ellurunaresh

Hi everyone, I tried to export an MBartForConditionalGeneration model, and the whole export process worked well. However, when I use onnxruntime to load the model, it raises the following error:

onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph: [ONNXRuntimeError] : 10 : INVALID_GRAPH : Load model from onnx_models/model_final.onnx failed:This is an invalid model. In Node, ("BeamSearch_zcode", BeamSearch, "com.microsoft", -1) : ("input_ids": tensor(int32),"emotion_mask": tensor(int32),"max_length": tensor(int32),"min_length": tensor(int32),"num_beams": tensor(int32),"num_return_sequences": tensor(int32),"length_penalty": tensor(float),"repetition_penalty": tensor(float),"renormalize_logits": tensor(bool),"remove_invalid_values": tensor(bool),"top_k": tensor(int32),"","","attention_mask": tensor(int32),) -> ("sequences": tensor(int32),) , Error Node (BeamSearch_zcode) has input size 14 not in range [min=5, max=10].

Here is my env info:

  • python==3.8
  • onnx==1.13.1
  • onnxruntime==1.14.0
  • pytorch==2.0.0

Quang-elec44 avatar Jun 16 '23 03:06 Quang-elec44

@Quang-elec44, your BeamSearch node has 14 inputs, but the script only adds 10: https://github.com/microsoft/onnxruntime/blob/1866a9d81879dd4294b4c8e105f916a4739664bc/onnxruntime/python/tools/transformers/models/bart/utils/chain_enc_dec_with_beamsearch.py#LL78C1-L87C26

How did you generate the model?

tianleiwu avatar Jun 16 '23 06:06 tianleiwu
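
A quick, hedged way to see what an exported model's BeamSearch node actually contains (both its inputs and its attributes) is to inspect the graph directly; the path below matches the one used earlier in this thread:

import onnx

model = onnx.load("onnx_models/model_final.onnx")
beam = next(node for node in model.graph.node if node.op_type == "BeamSearch")
print(len(beam.input), list(beam.input))       # empty strings mark skipped optional inputs
print([attr.name for attr in beam.attribute])  # should include the decoder and encoder subgraphs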

@tianleiwu Here is my script

import os

import onnx
from onnx import TensorProto, helper
from utils import export_helper


def make_dim_proto_numeric(model, config):
    """Make dim_proto numeric.

    Args:
        model (onnx.ModelProto): exported encoder/decoder ONNX model, updated in place.
        config: Bart config.
    """
    sequence_length = str(1)
    num_heads = str(config.encoder_attention_heads)
    hidden_size = str(config.d_model)
    head_size = str(config.encoder_attention_heads)

    for tensor in model.graph.output:
        for dim_proto in tensor.type.tensor_type.shape.dim:
            if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
                sequence_length,
                num_heads,
                hidden_size,
                head_size,
            ]:
                dim_value = int(dim_proto.dim_param)
                dim_proto.Clear()
                dim_proto.dim_value = dim_value

    for tensor in model.graph.input:
        for dim_proto in tensor.type.tensor_type.shape.dim:
            if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
                sequence_length,
                num_heads,
                hidden_size,
                head_size,
            ]:
                dim_value = int(dim_proto.dim_param)
                dim_proto.Clear()
                dim_proto.dim_value = dim_value


def convert_model(args):
    """Combine encoder, decoder, and beam search op to convert ONNX model.

    Using beam search op to connect encoder and decoder of the model, and convert it into one ONNX model.

    Args:
        args: User input.
    """
    config, _ = export_helper.initialize_config(args)

    eos_token_id = config.eos_token_id
    pad_token_id = config.pad_token_id
    decoder_start_token_id = config.decoder_start_token_id

    encoder_path = os.path.join(args.output, "edinit.onnx")
    decoder_path = os.path.join(args.output, "decoder_past.onnx")
    final_path = os.path.join(args.output, "model_final.onnx")

    encoder_model = onnx.load(encoder_path, load_external_data=True)
    encoder_model.graph.name = "encoderdecoderinit subgraph"
    make_dim_proto_numeric(encoder_model, config)

    decoder_model = onnx.load(decoder_path, load_external_data=True)
    decoder_model.graph.name = "decoder subgraph"
    make_dim_proto_numeric(decoder_model, config)

    inputs = [
        "input_ids",
        "emotion_mask",
        "max_length",
        "min_length",
        "num_beams",
        "num_return_sequences",
        "length_penalty",
        "renormalize_logits",
        "remove_invalid_values",
        "top_k",
        "",
        "",
        "attention_mask",
    ]
    outputs = ["sequences"]

    node = helper.make_node("BeamSearch", inputs=inputs, outputs=outputs, name="BeamSearch_zcode")
    node.domain = "com.microsoft"
    # NOTE: take value from args or config
    node.attribute.extend(
        [
            helper.make_attribute("eos_token_id", eos_token_id),
            helper.make_attribute("pad_token_id", pad_token_id),
            helper.make_attribute("decoder_start_token_id", decoder_start_token_id),
            helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
            helper.make_attribute("early_stopping", args.early_stopping),
            helper.make_attribute("model_type", 1),
            helper.make_attribute("decoder", decoder_model.graph),
            helper.make_attribute("encoder", encoder_model.graph),
        ]
    )

    # graph inputs
    input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
    emotion_mask = helper.make_tensor_value_info("emotion_mask", TensorProto.INT32, ["batch_size", "sequence_length"])
    max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
    min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
    num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
    top_k = helper.make_tensor_value_info("top_k", TensorProto.INT32, [1])
    renormalize_logits = helper.make_tensor_value_info("renormalize_logits", TensorProto.BOOL, [1])
    remove_invalid_values = helper.make_tensor_value_info("remove_invalid_values", TensorProto.BOOL, [1])
    num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
    length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
    attention_mask = helper.make_tensor_value_info(
        "attention_mask", TensorProto.INT32, ["batch_size", "sequence_length"]
    )

    graph_inputs = [
        input_ids,
        emotion_mask,
        max_length,
        min_length,
        num_beams,
        num_return_sequences,
        length_penalty,
        renormalize_logits,
        remove_invalid_values,
        top_k,
        attention_mask
    ]

    # graph outputs
    sequences = helper.make_tensor_value_info(
        "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
    )
    initializers = []
    graph_outputs = [sequences]
    new_graph = helper.make_graph([node], "beam-search-test", graph_inputs, graph_outputs, initializers)

    opset_import = helper.make_opsetid(domain="com.microsoft", version=1)
    # Create the model
    decoder_model.opset_import.append(opset_import)
    new_model = helper.make_model(
        new_graph, producer_name="onnxruntime.transformers", opset_imports=decoder_model.opset_import
    )
    # https://github.com/onnx/onnx/blob/main/onnx/helper.py
    onnx.save(new_model, final_path, save_as_external_data=True, all_tensors_to_one_file=False, convert_attribute=True)
    # check model > 2GB
    print(f"--- Check the model with path: {final_path} ---")
    onnx.checker.check_model(final_path, full_check=True)
    onnx.shape_inference.infer_shapes_path(final_path, strict_mode=True)

Quang-elec44 avatar Jun 16 '23 07:06 Quang-elec44

@tianleiwu It's weird that the BeamSearch inputs are restricted to the range of 5 to 10. And by the way, why are there some empty-string inputs?

Quang-elec44 avatar Jun 16 '23 07:06 Quang-elec44

@Quang-elec44, your script is not correct. See https://github.com/microsoft/onnxruntime/blob/rel-1.14.0/docs/ContribOperators.md#com.microsoft.BeamSearch. "5-10" means 5 inputs are required and the others are optional; an empty string means no input is provided for that optional slot.

tianleiwu avatar Jun 16 '23 15:06 tianleiwu
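
For reference, a hedged sketch of an input list that stays within the op's 5-10 slots, following the layout the BART script in this repo uses; the two empty strings skip the optional vocab_mask and prefix_vocab_mask slots, and extra model inputs such as emotion_mask, renormalize_logits, remove_invalid_values, or top_k are not part of this signature:

from onnx import helper

inputs = [
    "input_ids",
    "max_length",
    "min_length",
    "num_beams",
    "num_return_sequences",
    "length_penalty",
    "repetition_penalty",
    "",  # vocab_mask (optional, skipped)
    "",  # prefix_vocab_mask (optional, skipped)
    "attention_mask",
]
node = helper.make_node("BeamSearch", inputs=inputs, outputs=["sequences"], name="BeamSearch_zcode")
node.domain = "com.microsoft"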

@tianleiwu Oh I see, some of the inputs in the GenerationConfig class do not exist among the BeamSearch op's 5-10 inputs. In my script, I have an additional input (emotion_mask), which is an encoder input of MBartForConditionalGeneration. I successfully exported the encoder and decoder, but I don't know how to pass this input through the beam search. Basically, it's one of the inputs of the generate function. Can you help me out?

Quang-elec44 avatar Jun 16 '23 15:06 Quang-elec44

@Quang-elec44, you will need to change the beam search code to support additional inputs: https://github.com/microsoft/onnxruntime/blob/64b22cd00f3af5c07552188773f4ec159f68044b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h#L142-L167. Basically, pass emotion_mask to encoder_feeds and run the encoder subgraph.

tianleiwu avatar Jun 21 '23 06:06 tianleiwu

@tianleiwu Thank you for the information. I managed to change my code so that it does not add new inputs, and I successfully exported my model. However, there are several memory problems:

  • The exported model/weights take up a huge amount of disk space (~4.9 GB), while my PyTorch model consumes only 1.7 GB. Could you please tell me the reason? (Screenshots attached: weights after export vs. my PyTorch model.)

  • One instance of the ONNX Runtime session consumes 5 GB of GPU memory and 7 GB of CPU memory, while two instances of the PyTorch model consume only 4.1 GB GPU and 8 GB CPU. Note: I set the following session options to False: enable_cpu_mem_arena, enable_mem_pattern, enable_mem_reuse.

Quang-elec44 avatar Jun 22 '23 03:06 Quang-elec44

@Quang-elec44, the following script (which supports GPT-2 and T5 right now) can identify initializers shared between the encoder and decoder. Since the BART ONNX model's inputs/outputs are similar to T5's, you can modify the script to apply it to BART: https://github.com/microsoft/onnxruntime/blob/8e8840f1de86f30304800126224d48c0ed1d2e51/onnxruntime/python/tools/transformers/convert_generation.py#L887

tianleiwu avatar Jun 22 '23 17:06 tianleiwu
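
The linked convert_generation.py is the authoritative implementation; as a rough illustration of the idea only, the sketch below finds initializers that are duplicated byte-for-byte between the encoder and decoder subgraphs. File names follow the export script earlier in this thread and are otherwise placeholders.

import onnx
from onnx import numpy_helper

encoder = onnx.load("onnx_models/edinit.onnx")
decoder = onnx.load("onnx_models/decoder_past.onnx")

# Index decoder initializers by their raw bytes; weights that also appear in the
# encoder could be stored once and shared instead of duplicated on disk.
decoder_by_bytes = {numpy_helper.to_array(t).tobytes(): t.name for t in decoder.graph.initializer}

shared = []
for tensor in encoder.graph.initializer:
    key = numpy_helper.to_array(tensor).tobytes()
    if key in decoder_by_bytes:
        shared.append((tensor.name, decoder_by_bytes[key]))

print(f"{len(shared)} initializers are duplicated across encoder and decoder")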