Fail to load .pt file generated with torch.save in python

Open 903124 opened this issue 3 years ago • 6 comments

Calling torch.load in Python on a .pt tensor file generated in R fails:

UnpicklingError: invalid load key, '\x1f'.

903124 avatar Jun 04 '21 04:06 903124

It depends on how you generated the file. If you used torch_save, this is expected: torch_save actually writes an RDS file, with small differences so that the tensors serialize correctly.
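The '\x1f' in the error is consistent with that: it is the first byte of the gzip magic number, and R's saveRDS gzip-compresses by default (assuming torch_save keeps that default), whereas recent versions of Python's torch.save and torch.jit.save write a zip archive. A minimal sketch for checking which container a .pt file uses, with only the standard library; `sniff_container` and the file names are hypothetical, and the demo files only mimic the containers rather than holding real tensors:

```python
import gzip
import os
import tempfile
import zipfile

GZIP_MAGIC = b"\x1f\x8b"  # first two bytes of any gzip stream


def sniff_container(path):
    """Return 'gzip', 'zip', or 'unknown' based on the file's contents."""
    with open(path, "rb") as fh:
        head = fh.read(2)
    if head == GZIP_MAGIC:
        return "gzip"  # what a compressed RDS file looks like
    if zipfile.is_zipfile(path):
        return "zip"  # what recent torch.save / torch.jit.save output looks like
    return "unknown"


# Demo on throwaway files that only mimic the two containers.
tmpdir = tempfile.mkdtemp()
gz_path = os.path.join(tmpdir, "from_r.pt")
zip_path = os.path.join(tmpdir, "from_python.pt")

with gzip.open(gz_path, "wb") as fh:
    fh.write(b"placeholder payload")
with zipfile.ZipFile(zip_path, "w") as zf:
    zf.writestr("data.pkl", b"placeholder payload")

print(sniff_container(gz_path))
print(sniff_container(zip_path))
```

A "gzip" result on a file written from R suggests it came from torch_save and cannot be read by torch.load.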

You can see this example: https://blogs.rstudio.com/ai/posts/2020-12-15-torch-0.2.0-released/#initial-jit-support and this one https://community.rstudio.com/t/r-model-serving-using-python-torchserve/94303/2 on how to save in the correct format to be loaded by python.

This is still a work in progress and will improve as we add more JIT support.

dfalbel avatar Jun 05 '21 02:06 dfalbel

Well, so in order to get it into Python right now I have to jit_trace a function that does nothing (e.g. subtracts 0)?

903124 avatar Jun 05 '21 02:06 903124

Basically, you need to jit_trace a function that just calls your model. But before that, you have to detach the model parameters.

dfalbel avatar Jun 05 '21 02:06 dfalbel

Cool, I changed the title and will leave it open then.

903124 avatar Jun 05 '21 02:06 903124

Sorry, I don't follow...

Are you trying to save in Python and load in R, or save in R and load in Python? A file written by Python's torch.save is also not compatible with R's torch_load(). There's a note about this here: https://torch.mlverse.org/docs/articles/serialization.html#loading-models-saved-in-python-1

You can also jit_load a python model saved with torch.jit.save.

If you are trying to use torch_save from R and load the result in Python with torch.load, that's not supported, by design. The alternative is using jit_trace/jit_save, but this is still a work in progress.
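On the Python side, a file written by jit_save is read with torch.jit.load rather than torch.load. The sketch below is only an illustration under assumptions: the `Scale` module and the file name are hypothetical, and the file is produced in Python with torch.jit.trace and torch.jit.save as a stand-in for the R jit_trace/jit_save step, so the example runs on its own.

```python
import os
import tempfile

import torch


class Scale(torch.nn.Module):
    """Hypothetical stand-in for the model that would be traced in R."""

    def forward(self, x):
        return x * 2


path = os.path.join(tempfile.mkdtemp(), "traced.pt")

# Stand-in for the R side: trace the module with an example input and save it.
traced = torch.jit.trace(Scale(), torch.ones(3))
torch.jit.save(traced, path)

# Python side: TorchScript files are loaded with torch.jit.load, not torch.load.
loaded = torch.jit.load(path)
result = loaded(torch.tensor([1.0, 2.0, 3.0]))
print(result)
```

The loaded object is callable like the original model, so no Python class definition is needed at load time.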

dfalbel avatar Jun 05 '21 12:06 dfalbel

I'm trying to save a tensor from R and load it in Python, and I'm not having much luck with jit. I can convert it to an array and turn it into a NumPy object with reticulate, but that's extra steps.
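For reference, the Python end of that workaround is roughly the following. This is a sketch under assumptions: `r_array` is a hypothetical stand-in for the array handed over by reticulate, which is assumed here to arrive as a float64, column-major NumPy array.

```python
import numpy as np
import torch

# Hypothetical stand-in for an R array passed through reticulate
# (assumed float64 and column-major).
r_array = np.asfortranarray(np.arange(6, dtype=np.float64).reshape(2, 3))

# from_numpy shares memory with the array and keeps its dtype (float64).
tensor = torch.from_numpy(r_array)

# Cast to float32 (PyTorch's default float dtype) and make the layout contiguous.
tensor32 = tensor.to(torch.float32).contiguous()

print(tensor32.dtype, tuple(tensor32.shape))
```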

903124 avatar Jun 05 '21 13:06 903124