tch-rs
How to save a tensor in order to load it with pytorch
I'm considering using tch-rs to preprocess a bunch of data for use in a PyTorch pipeline. Thanks for implementing these bindings!
When saving a tch::Tensor like this:
// tch = "0.6.1"
use tch::Tensor;

fn main() {
    Tensor::of_slice(&[1, 2, 3])
        .save("test.pt")
        .expect("unable to save tensor");
}
Loading it with torch.load in Python does not return the tensor; it emits a warning and returns a script module instead:
% python
Python 3.8.3 (default, Jul 2 2020, 11:26:31)
[Clang 10.0.0 ] :: Anaconda, Inc. on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.__version__
'1.10.0'
>>> torch.load("test.pt")
/Users/user/opt/anaconda3/lib/python3.8/site-packages/torch/serialization.py:602: UserWarning: 'torch.load' received a zip file that looks like a TorchScript archive dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to silence this warning)
warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
RecursiveScriptModule(original_name=Module)
Loading it with torch.jit.load appears to work like this:
>>> torch.jit.load('test.pt').state_dict()['0']
tensor([1, 2, 3], dtype=torch.int32)
Is this the proper way to save and load data between tch-rs and pytorch? Or is there a way to save tch::Tensor so that I can load it with torch.load?
torch.save and torch.load rely on Python's pickle module to marshal binary data, whereas scripted modules (as used here) are cross-platform. With some hacking you could probably make serde do what you want, but I'm not positive. Since tch-rs is a wrapper around libtorch, you may not get Python front-end APIs such as pickle serialization.
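If it helps to tell the two kinds of file apart before choosing a loader, here is a rough sketch using only the standard library. It assumes, based on the warning above and my reading of torch/serialization.py, that a TorchScript archive is a zip file containing a constants.pkl entry; that is an internal detail and may differ between PyTorch versions. The function name looks_like_torchscript is made up for this example.

```python
import zipfile


def looks_like_torchscript(path):
    """Heuristic check: is `path` a zip archive with a constants.pkl entry?

    Assumption: TorchScript archives carry a constants.pkl record, which
    plain torch.save zip files do not. Treat the result as a hint only.
    """
    if not zipfile.is_zipfile(path):
        # Not a zip archive at all (e.g. a legacy torch.save file).
        return False
    with zipfile.ZipFile(path) as zf:
        # Entries are stored under a top-level folder, so compare base names.
        return any(name.split("/")[-1] == "constants.pkl" for name in zf.namelist())
```

If this returns True for a file, torch.jit.load is the loader to reach for, and the tensors can then be pulled out of state_dict() as shown earlier in the thread.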
tch now supports the safetensors format, which works well for interop with the Python side, so I would recommend using it if you can. Closing this as the issue has had no activity in a while, but feel free to re-open it if it's still a problem.
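To illustrate why safetensors is easy to share across languages, here is a minimal standard-library sketch of the file layout as I understand it from the format description: an 8-byte little-endian header length, a JSON header mapping each tensor name to its dtype, shape, and byte offsets, then the raw data. The helper names below are hypothetical and exist only for this example. In practice you would more likely use safetensors.torch.load_file on the Python side and, to my knowledge, Tensor::write_safetensors in recent tch releases on the Rust side; check the docs for your version.

```python
import json
import struct


def write_int32_safetensors(path, name, values):
    """Write one 1-D int32 tensor in safetensors layout (illustration only)."""
    data = struct.pack(f"<{len(values)}i", *values)
    header = {
        name: {"dtype": "I32", "shape": [len(values)], "data_offsets": [0, len(data)]}
    }
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # little-endian u64 length
        f.write(header_bytes)
        f.write(data)


def read_safetensors_header(path):
    """Return the JSON header: tensor name -> dtype, shape, data_offsets."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len).decode("utf-8"))


def read_int32_tensor(path, name):
    """Read a 1-D int32 tensor back as a list of Python ints."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
        begin, end = header[name]["data_offsets"]
        # Offsets are relative to the start of the data section.
        f.seek(8 + header_len + begin)
        raw = f.read(end - begin)
    return list(struct.unpack(f"<{len(raw) // 4}i", raw))
```

Because the header is plain JSON and the payload is raw bytes, no pickle step is involved, so a file written from Rust can be inspected or loaded from Python without going through torch.jit.load.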