
Computing loss within onnxruntime inference (GPT2 model)

Open OriAlpha opened this issue 3 years ago • 4 comments

This could be a feature request. I am trying to compute language-model metrics, i.e., the perplexity score. In PyTorch it is convenient to get the score with:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)

inputs = tokenizer("here is an example of gpt2 model", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss  # mean next-token cross-entropy
perplexity = torch.exp(loss)

But with ONNX Runtime it is not clear how I can compute this loss.
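For reference, `outputs.loss` in the PyTorch snippet is the mean next-token cross-entropy: GPT2LMHeadModel shifts the labels by one position internally. Assuming, to keep the formula simple, a single unpadded sequence of T tokens x_1, ..., x_T, the loss and perplexity are:

```latex
\mathcal{L} = -\frac{1}{T-1} \sum_{t=2}^{T} \log p_\theta\left(x_t \mid x_1, \dots, x_{t-1}\right),
\qquad
\mathrm{PPL} = \exp(\mathcal{L})
```

With padding, positions whose label is -100 are excluded from the average.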

System information

  • OS: Linux 20.04
  • ONNX Runtime: 1.12.1
  • Python: 3.8

I have seen a post regarding this: https://github.com/microsoft/onnxruntime/issues/8675

Following the suggestion in that issue, I added labels=input_ids to the forward call:

https://github.com/microsoft/onnxruntime/blob/064a385b59598419c287a4b897afca376227d3bd/onnxruntime/python/tools/transformers/gpt2_helper.py#L84-L88

  def forward(self, input_ids, position_ids, attention_mask, *past):
      result = super().forward(
          input_ids,
          labels=input_ids,
          position_ids=position_ids,
          attention_mask=attention_mask,
          past_key_values=past,
          return_dict=False,
      )

And in the export_onnx function:
https://github.com/microsoft/onnxruntime/blob/064a385b59598419c287a4b897afca376227d3bd/onnxruntime/python/tools/transformers/gpt2_helper.py#L265

  def export_onnx(
      model,
      device,
      onnx_model_path: str,
      verbose: bool = False,
      use_external_data_format: bool = False,
      has_position_ids: bool = True,
      has_attention_mask: bool = True,
      input_ids_dtype: torch.dtype = torch.int32,
      position_ids_dtype: torch.dtype = torch.int32,
      attention_mask_dtype: torch.dtype = torch.int32,
      labels: torch.dtype = torch.int32,
  ):

But now I run into a strange error:

[E:onnxruntime:, sequential_executor.cc:368 Execute] Non-zero status code returned while running GatherElements node. Name:'GatherElements_9918' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/tensor/gather_elements.cc:154 void onnxruntime::core_impl(const onnxruntime::Tensor*, const onnxruntime::Tensor*, onnxruntime::Tensor*, int64_t, onnxruntime::concurrency::ThreadPool*) [with Tin = long int; int64_t = long int] GatherElements op: Out of range value in index tensor

OriAlpha, Aug 09 '22 17:08

I think you will need to export the model with labels added to the inputs and loss added to the outputs. See the corresponding part of the Hugging Face Python code: https://github.com/huggingface/transformers/blob/8cf4a6f0a63ed3aeed68192a9304fed2bd0ce100/src/transformers/models/gpt2/modeling_gpt2.py#L1087-L1094

You will need to update the interface like this:

  def forward(self, labels, input_ids, position_ids, attention_mask, *past):
      result = super().forward(
          input_ids,
          labels=labels,
          position_ids=position_ids,
          attention_mask=attention_mask,
          past_key_values=past,
          return_dict=False,
      )
      # TODO: get the loss and the other output fields from result.
      # You can set a breakpoint here in a debugger to inspect them.
      loss = result[0]  # with labels given, the loss is the first element
      other_results = result[1:]  # logits, present key/values, ...
      return loss, other_results

When this logic is exported to ONNX, you should see some nodes added to the graph that compute the loss from the logits and labels.
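Those nodes compute the same shift-by-one cross-entropy as the Hugging Face code linked above. As an alternative not proposed in the thread, the same quantity can be computed in NumPy from the `logits` array that an unmodified GPT-2 ONNX model already returns (for example, from `session.run`). This is a minimal sketch assuming logits of shape (batch, seq_len, vocab) and -100 as the ignore label, matching PyTorch's CrossEntropyLoss default; `causal_lm_loss` and `perplexity` are hypothetical names:

```python
import numpy as np


def causal_lm_loss(logits: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """Mean next-token cross-entropy with GPT2LMHeadModel-style label shifting.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len).
    """
    # Position t predicts token t + 1: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :].astype(np.float64)
    shift_labels = labels[:, 1:]

    # Numerically stable log-softmax over the vocabulary axis.
    max_logits = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - max_logits - np.log(
        np.exp(shift_logits - max_logits).sum(axis=-1, keepdims=True)
    )

    # Skip positions marked with ignore_index (e.g. padding).
    mask = shift_labels != ignore_index
    safe_labels = np.where(mask, shift_labels, 0)
    token_log_probs = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]
    return float(-(token_log_probs * mask).sum() / mask.sum())


def perplexity(loss: float) -> float:
    return float(np.exp(loss))


# Usage: uniform logits over a vocabulary of 4 give loss = ln(4), perplexity = 4.
logits = np.zeros((1, 5, 4), dtype=np.float32)
labels = np.array([[0, 1, 2, 3, 0]])
loss = causal_lm_loss(logits, labels)
print(round(perplexity(loss), 6))  # 4.0
```

Doing the reduction outside the graph keeps the exported model unchanged; the trade-off is that the full logits tensor has to be copied out of ONNX Runtime.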

You can also export the loss calculation as a separate ONNX model that takes logits and labels as inputs and produces loss as output. Just wrap the Python code above in a class, then export it to ONNX.
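A minimal sketch of that second option, under stated assumptions: `CausalLMLoss` and `export_loss_model` are hypothetical names, the loss mirrors the Hugging Face shift-then-cross-entropy logic, and the export call uses the TorchScript-style `torch.onnx.export` arguments (`input_names`, `output_names`, `dynamic_axes`). Recent PyTorch releases may warn or route to a different exporter, so treat the export part as untested:

```python
import torch
import torch.nn.functional as F


class CausalLMLoss(torch.nn.Module):
    """Hypothetical standalone loss model: (logits, labels) -> scalar loss."""

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Shift so that position t is scored against token t + 1.
        shift_logits = logits[:, :-1, :]
        shift_labels = labels[:, 1:]
        return F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
        )


def export_loss_model(path: str, vocab_size: int = 50257) -> None:
    # Dummy tensors only fix ranks and dtypes; batch and sequence length stay dynamic.
    logits = torch.randn(1, 8, vocab_size)
    labels = torch.randint(0, vocab_size, (1, 8), dtype=torch.int64)
    torch.onnx.export(
        CausalLMLoss(),
        (logits, labels),
        path,
        input_names=["logits", "labels"],
        output_names=["loss"],
        dynamic_axes={
            "logits": {0: "batch_size", 1: "seq_len"},
            "labels": {0: "batch_size", 1: "seq_len"},
        },
        opset_version=14,
    )
```

Once exported, the model could be run with `onnxruntime.InferenceSession(path).run(["loss"], {"logits": logits, "labels": labels})`, feeding the logits produced by the GPT-2 ONNX model and int64 labels.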

tianleiwu, Aug 10 '22 00:08

OK, I modified the class in onnxruntime/onnxruntime/python/tools/transformers/gpt2_helper.py as follows:

class MyGPT2LMHeadModel(GPT2LMHeadModel):
    """Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state."""

    def __init__(self, config):
        super().__init__(config)

    def forward(self, labels, input_ids, position_ids, attention_mask, *past):
        result = super().forward(
            input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past,
            return_dict=False,
        )
        loss = get_loss(result)

        return MyGPT2Model.post_process(loss, result, self.config.n_layer)

But it is still the same issue. Do I have to import the get_loss() function, or recreate it from transformers?

OriAlpha, Aug 10 '22 07:08

@OriAlpha, see the related code: https://github.com/huggingface/transformers/blob/8cf4a6f0a63ed3aeed68192a9304fed2bd0ce100/src/transformers/models/gpt2/modeling_gpt2.py#L1096-L1107. The first element of the returned tuple is the loss, so we can define:

  def get_loss(result):
      return result[0]

If you use return_dict=True, the result contains a field called loss.
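A small sketch covering both cases, assuming labels were passed to forward(). This `get_loss` is a hypothetical helper, not part of transformers or gpt2_helper.py. Note that if labels are not passed, the loss is None and `result[0]` is the logits instead, which is worth checking when the exported outputs look wrong:

```python
from types import SimpleNamespace


def get_loss(result):
    """Hypothetical helper: extract the loss from a GPT2LMHeadModel forward result.

    With return_dict=False the result is a plain tuple whose first element is
    the loss; with return_dict=True it is a ModelOutput object (not a tuple)
    that exposes the loss as an attribute.
    """
    if isinstance(result, tuple):
        return result[0]
    return result.loss


# Usage with stand-in objects (no model needed):
tuple_style = (2.5, "logits", "presents")
dict_style = SimpleNamespace(loss=2.5, logits="logits")
print(get_loss(tuple_style), get_loss(dict_style))  # 2.5 2.5
```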

tianleiwu, Aug 11 '22 20:08

Yes, but I need this loss behaviour inside the ONNX Runtime model, as I mentioned above. @tianleiwu, do I need to edit the functions in gpt2_beamsearch_helper.py or in gpt2_helper.py? I am a bit confused now.

OriAlpha, Aug 11 '22 20:08

@OriAlpha, you will need other changes. For example, creating the dummy inputs needs a real tensor for labels, and you need to add dynamic axes settings for labels and loss. If you are familiar with torch.onnx.export, it is not hard. You can make those changes in gpt2_helper.py: wherever input_ids is handled, you might need to add similar logic for the labels input, and wherever the logits output is handled, similar logic for the loss output.
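A hedged sketch of the names and dynamic axes one might pass to torch.onnx.export after adding a labels input and a loss output. `build_export_io` is hypothetical, and the names and axis indices (`past_{i}`, `present_{i}`, batch on axis 1 and sequence length on axis 3 of the past/present tensors) are assumed to mirror gpt2_helper.py; check them against the actual file. `input_names` is ordered to match a `forward(self, labels, input_ids, position_ids, attention_mask, *past)` signature, because torch.onnx.export assigns names positionally. For the dummy inputs, a real labels tensor such as `input_ids.clone()` works; int64 is the safer dtype, since PyTorch's cross-entropy has historically required class-index targets to be `torch.long`.

```python
def build_export_io(num_layers: int):
    """Hypothetical helper: input/output names and dynamic axes for exporting
    GPT-2 with past state plus a labels input and a loss output."""
    past_names = [f"past_{i}" for i in range(num_layers)]
    present_names = [f"present_{i}" for i in range(num_layers)]

    # Order must follow forward(self, labels, input_ids, position_ids, attention_mask, *past).
    input_names = ["labels", "input_ids", "position_ids", "attention_mask"] + past_names
    # Assumes forward returns loss first, then logits, then one present tensor per layer.
    output_names = ["loss", "logits"] + present_names

    dynamic_axes = {
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "position_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "total_seq_len"},
        # labels has the same shape as input_ids, so it gets the same axes.
        "labels": {0: "batch_size", 1: "seq_len"},
        "logits": {0: "batch_size", 1: "seq_len"},
        # loss is a 0-d scalar after mean reduction: no axes to mark as dynamic.
    }
    for name in past_names:
        dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"}
    for name in present_names:
        dynamic_axes[name] = {1: "batch_size", 3: "total_seq_len"}
    return input_names, output_names, dynamic_axes


inputs, outputs, axes = build_export_io(num_layers=2)
print(inputs)
print(outputs)
```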

tianleiwu, Aug 16 '22 06:08

If possible, could @tianleiwu provide a code snippet for computing the loss?

OriAlpha, Aug 21 '22 16:08