Serialization problem in PyTorch model
I encountered a problem when loading back the serialized PyTorch model (PyTorch version 0.4.1):
Traceback (most recent call last):
  File "/container/pytorch_container.py", line 82, in <module>
    rpc_service.get_input_type())
  File "/container/pytorch_container.py", line 54, in __init__
    self.model = load_pytorch_model(torch_model_path, torch_weights_path)
  File "/container/pytorch_container.py", line 38, in load_pytorch_model
    model.load_state_dict(torch.load(weights_path))
AttributeError: 'int' object has no attribute 'load_state_dict'
Switching to PyTorch's built-in serialization (https://pytorch.org/docs/stable/notes/serialization.html) solved the problem.
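For reference, the approach the PyTorch serialization notes recommend is to save only the model's `state_dict()` and, at load time, re-create the architecture in code before calling `load_state_dict`. The sketch below assumes a small illustrative model (`TinyNet`) and hypothetical helpers `save_weights` / `load_weights`; it is not the actual Clipper deployer code.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Illustrative model (hypothetical); stands in for the model being deployed."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


def save_weights(model, path):
    # Persist only parameters and buffers, not the pickled Python object.
    torch.save(model.state_dict(), path)


def load_weights(path):
    # Re-create the architecture in code, then load the saved parameters into it.
    model = TinyNet()
    model.load_state_dict(torch.load(path))
    model.eval()  # switch to inference mode before serving
    return model


if __name__ == "__main__":
    torch.manual_seed(0)
    original = TinyNet()
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "weights.pt")
        save_weights(original, path)
        restored = load_weights(path)
    x = torch.randn(3, 4)
    print(torch.equal(original(x), restored(x)))  # True
```

Because only tensors are stored, the loading side never has to unpickle an arbitrary object, so a mismatch shows up as a clear `load_state_dict` error rather than an unexpected type.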
To solve this problem, I updated the serialization functions in deployer_utils.py and the PyTorch deployer. I also built a new custom image with an updated pytorch_container.py. Please let me know if I should send a PR for this :)
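On the container side, a type check before calling `load_state_dict` would turn the cryptic `'int' object has no attribute 'load_state_dict'` into an explicit error. This is a minimal sketch with a hypothetical helper name (`load_weights_into`); it is not the code from the updated pytorch_container.py, and it assumes the model object has already been reconstructed by some earlier step.

```python
import torch
import torch.nn as nn


def load_weights_into(model, weights_path):
    """Hypothetical helper: load a saved state_dict into an already-built model.

    `weights_path` may be a path or a file-like object, since torch.load accepts both.
    """
    # Fail fast with a readable message if deserialization produced the wrong type.
    if not isinstance(model, nn.Module):
        raise TypeError(
            "expected a torch.nn.Module, got {!r}; the serialized model "
            "was probably not restored correctly".format(type(model).__name__)
        )
    model.load_state_dict(torch.load(weights_path))
    model.eval()
    return model
```

With this check, passing the `int` from the traceback raises `TypeError: expected a torch.nn.Module, got 'int'; ...` at the point where the bad object is first used.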
Yes. Please send a PR if you are interested! Thanks a lot!