
PyTorch MBart Model - Trace on CPU and run inference on GPU.

gnovack opened this issue 1 year ago

System Info

  • transformers version: 4.26.1
  • Platform: Linux-5.10.157-139.675.amzn2.x86_64-x86_64-with-glibc2.26
  • Python version: 3.9.15
  • Huggingface_hub version: 0.13.0
  • PyTorch version (GPU?): 1.13.1 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

Information

  • [ ] The official example scripts
  • [x] My own modified scripts

Tasks

  • [ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

  1. Load MBart model and trace it on CPU with torch.jit.trace()
import torch
from transformers import MBartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
example_english_phrase = "UN Chief Says There Is No Military Solution in Syria"
expected_translation_romanian = "Şeful ONU declară că nu există o soluţie militară în Siria"

inputs = tokenizer(example_english_phrase, text_target=expected_translation_romanian, return_tensors="pt")

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50", torchscript=True)

traced_model = torch.jit.trace(model, [inputs.input_ids, inputs.attention_mask])
torch.jit.save(traced_model, "mbart-traced.pt")
  2. Load the traced model and place it on GPU using torch.jit.load()
loaded_model_gpu = torch.jit.load("mbart-traced.pt", map_location=torch.device('cuda'))
  3. Run inference on GPU
loaded_model_gpu(inputs.input_ids.to('cuda'), inputs.attention_mask.to('cuda'))

The following error is raised while running inference:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/transformers/models/mbart/modeling_mbart/___torch_mangle_1394.py", line 15, in forward
    lm_head = self.lm_head
    model = self.model
    _0 = (model).forward(input_ids, attention_mask, )
          ~~~~~~~~~~~~~~ <--- HERE
    _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, = _0
    _51 = torch.add((lm_head).forward(_1, ), final_logits_bias)
  File "code/__torch__/transformers/models/mbart/modeling_mbart/___torch_mangle_1392.py", line 31, in forward
    _7 = torch.slice(prev_output_tokens0, 0, 0, 9223372036854775807)
    _8 = torch.fill_(torch.select(_7, 1, 0), decoder_start_tokens)
    _9 = (encoder).forward(embed_tokens, weight, input_ids, attention_mask, )
          ~~~~~~~~~~~~~~~~ <--- HERE
    _10 = (decoder).forward(weight, prev_output_tokens0, attention_mask, _9, )
    _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, = _10
  File "code/__torch__/transformers/models/mbart/modeling_mbart/___torch_mangle_1181.py", line 47, in forward
    _13 = (argument_1).forward(weight, input, )
    inputs_embeds = torch.mul(_13, CONSTANTS.c1)
    _14 = (embed_positions).forward(input_ids, )
           ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    input0 = torch.add(inputs_embeds, _14)
    _15 = (layernorm_embedding).forward(input0, )
  File "code/__torch__/transformers/models/mbart/modeling_mbart/___torch_mangle_1045.py", line 17, in forward
    positions = torch.expand(_2, [_0, -1])
    input = torch.add(positions, CONSTANTS.c3)
    return torch.embedding(weight, input)
           ~~~~~~~~~~~~~~~ <--- HERE

Traceback of TorchScript, original code (most recent call last):
...
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)

Also, by running dump_to_str I can see that the device is hard-coded to cpu within MBartLearnedPositionalEmbedding:

>>> loaded_model_gpu._c.dump_to_str(True, False, False)
module __torch__.transformers.models.mbart.modeling_mbart.___torch_mangle_4565.MBartLearnedPositionalEmbedding {
  parameters {
    weight = ...
  }
  attributes {
    weight = ...
    training = False
    _is_full_backward_hook = None
  }
  methods {
    method forward {
      graph(%self.1 : __torch__.transformers.models.mbart.modeling_mbart.___torch_mangle_4565.MBartLearnedPositionalEmbedding,
            %input_ids.1 : Tensor):
        %34 : Tensor = prim::Constant[value={2}]() # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:133:0
        %25 : bool = prim::Constant[value=0]() # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:129:0
        %52 : Device = prim::Constant[value="cpu"]()
        %22 : NoneType = prim::Constant() # :0:0
        %16 : Tensor = prim::Constant[value={0}]() # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:130:0
        %5 : int = prim::Constant[value=0]() # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:128:0
        %12 : int = prim::Constant[value=1]() # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:128:0
        %21 : int = prim::Constant[value=4]() # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:129:0
        %29 : int = prim::Constant[value=-1]() # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:129:0
        %weight.1 : Tensor = prim::GetAttr[name="weight"](%self.1)
        %6 : int = aten::size(%input_ids.1, %5) # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:128:0
        %bsz.1 : Tensor = prim::NumToTensor(%6) # :0:0
        %10 : int = aten::Int(%bsz.1) # :0:0
        %13 : int = aten::size(%input_ids.1, %12) # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:128:0
        %seq_len.1 : Tensor = prim::NumToTensor(%13) # :0:0
        %18 : Tensor = aten::add(%seq_len.1, %16, %12) # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:130:0
        %19 : Scalar = aten::ScalarImplicit(%18) # :0:0
        %26 : Tensor = aten::arange(%5, %19, %21, %22, %52, %25) # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:129:0
        %30 : int[] = prim::ListConstruct(%10, %29)
        %positions.1 : Tensor = aten::expand(%26, %30, %25) # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:129:0
        %input.1 : Tensor = aten::add(%positions.1, %34, %12) # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py:133:0
        %42 : Tensor = aten::embedding(%weight.1, %input.1, %29, %25, %25) # /home/gnovack/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages/torch/nn/functional.py:2210:0
        return (%42)
  
    }
  }
  submodules {
  }
}
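For context, the forward pass that produces this graph looks roughly like the following in transformers 4.26 (a paraphrase of MBartLearnedPositionalEmbedding from modeling_mbart.py, not an exact copy). The device=self.weight.device argument to torch.arange is evaluated once at trace time, which is how the prim::Constant[value="cpu"] above ends up frozen into the graph where map_location cannot rewrite it:

import torch
from torch import nn

class MBartLearnedPositionalEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        # MBart offsets the position ids by 2
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len = input_ids.shape[:2]
        # device=self.weight.device is resolved to a concrete device (cpu here)
        # during tracing and stored as a constant in the TorchScript graph.
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len,
            dtype=torch.long, device=self.weight.device,
        ).expand(bsz, -1)
        return super().forward(positions + self.offset)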

Expected behavior

I expected to be able to run inference successfully on GPU.

I have come across some similar issues related to other types of models:

  • https://github.com/huggingface/transformers/issues/5664
  • https://github.com/pytorch/pytorch/issues/50971

And some PRs that address similar issues:

  • https://github.com/huggingface/transformers/pull/11252
  • https://github.com/huggingface/transformers/pull/12290
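
For reference, one possible workaround (untested here) that sidesteps the problem, at the cost of the original goal of tracing on CPU, is to move the model and example inputs to the GPU before tracing, so that the device constant recorded in the graph is cuda rather than cpu. A minimal sketch along the lines of the reproduction above:

import torch
from transformers import MBartForConditionalGeneration, MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50", src_lang="en_XX", tgt_lang="ro_RO")
inputs = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50", torchscript=True)
model = model.to("cuda").eval()

# Tracing with CUDA tensors records a cuda device constant in the graph instead of cpu.
traced_model = torch.jit.trace(
    model, [inputs.input_ids.to("cuda"), inputs.attention_mask.to("cuda")]
)
torch.jit.save(traced_model, "mbart-traced-gpu.pt")

loaded_model_gpu = torch.jit.load("mbart-traced-gpu.pt", map_location=torch.device("cuda"))
loaded_model_gpu(inputs.input_ids.to("cuda"), inputs.attention_mask.to("cuda"))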

gnovack avatar Mar 08 '23 19:03 gnovack

cc @ArthurZucker and @younesbelkada

sgugger avatar Mar 08 '23 20:03 sgugger

EDIT: in order to actually solve this, we would need evidence of a lot of potential usage. The reason is that after fixing the positional ids with a registered buffer, we would also need to turn the causal attention mask into a buffer, otherwise it does not work. That means a lot of refactoring across a lot of models (even if we just fix this one, it is still a bit too much): we would have to implement the same logic as in GPT2 and GPTNeo. See the illustrative sketch below.
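
For illustration only, here is a rough sketch of the buffer-based approach described above (hypothetical names, not an actual proposed patch): the position ids live in a registered buffer, so they follow the module when it is moved to another device instead of being baked in as a cpu constant at trace time. As noted, a real fix would also have to treat the causal attention mask the same way.

import torch
from torch import nn

class BufferedLearnedPositionalEmbedding(nn.Embedding):
    # Hypothetical variant of MBartLearnedPositionalEmbedding using a registered buffer.
    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)
        self.register_buffer(
            "position_ids", torch.arange(num_embeddings), persistent=False
        )

    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, seq_len = input_ids.shape[:2]
        # Slicing a buffer keeps the result on whatever device the module lives on,
        # so the traced graph no longer contains a hard-coded device.
        positions = self.position_ids[
            past_key_values_length : past_key_values_length + seq_len
        ].expand(bsz, -1)
        return super().forward(positions + self.offset)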

ArthurZucker avatar Apr 11 '23 15:04 ArthurZucker

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar May 06 '23 15:05 github-actions[bot]