Question about saving OFA as a jit script

Open 25icecreamflavors opened this issue 2 years ago • 4 comments

Hello, I want to use OFA and need to save it.

  1. Can I somehow save it via torch jit script? I tried, but there is an error.
  2. Also, is it possible to use OFA in TensorFlow and save the model there somehow?

When I save via torch script, I get the following error:

RuntimeError: 
'Tensor' object has no attribute or method 'bool'.:
  File "/content/OFA/models/ofa/unify_transformer.py", line 544
        h, w = image_embed.shape[-2:]
        image_num_patches = h * w
        image_padding_mask = patch_images.new_zeros((patch_images.size(0), image_num_patches)).bool()
                             ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

25icecreamflavors avatar Jul 25 '22 03:07 25icecreamflavors

I haven't done this before, but I'll try to save it with torch jit script later. I wonder where the error comes from. Would you mind sharing more details, such as the PyTorch version?

JustinLin610 avatar Jul 25 '22 08:07 JustinLin610

Thank you. Basically, I just ran your demo Colab notebook for image captioning. The torch version is 1.12.0+cu113. I ran all cells and at the end I just do:

saved_model = torch.jit.script(model)
saved_model.save('saved_model.pt')

25icecreamflavors avatar Jul 25 '22 19:07 25icecreamflavors

The problem seems to come from here: https://github.com/pytorch/pytorch/issues/70544

JustinLin610 avatar Jul 26 '22 08:07 JustinLin610

Hi, the main problem here is that jit.script does not currently support the bool() method of Tensor. A simple workaround is to change some_tensor.new_zeros(xx, yy).bool() to some_tensor.new_zeros(xx, yy, dtype=torch.bool). You could modify the OFA code locally as a temporary fix. Maybe we could add support for jit script in the next version? @JustinLin610
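
A minimal sketch of that workaround, for illustration only. PaddingMask is a hypothetical stand-in module, not OFA code, and the input shapes are made up. It assumes a PyTorch build where new_zeros accepts dtype under torch.jit.script, and it saves to an in-memory buffer instead of 'saved_model.pt':

```python
import io

import torch
from torch import nn


class PaddingMask(nn.Module):
    """Hypothetical stand-in for the mask-building step; not OFA code."""

    def forward(self, patch_images: torch.Tensor, image_num_patches: int) -> torch.Tensor:
        # Pattern reported to fail under torch.jit.script in this thread:
        #   patch_images.new_zeros((batch, image_num_patches)).bool()
        # Workaround: request a boolean tensor directly via dtype.
        return patch_images.new_zeros(
            [patch_images.size(0), image_num_patches], dtype=torch.bool
        )


# Script the module, then round-trip it through save/load.
scripted = torch.jit.script(PaddingMask())
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
reloaded = torch.jit.load(buffer)

# Made-up input: batch of 2 images, 16 patches each.
mask = reloaded(torch.rand(2, 3, 8, 8), 16)
```

In the real model the same one-line change would go in unify_transformer.py at the line shown in the traceback. Other spots in the code may hit separate scripting errors, so this only addresses the one reported here.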

geekinglcq avatar Jul 27 '22 12:07 geekinglcq