PytorchWCT
Fix: Read Lua weights with PyTorch > 1.0
from torch.utils.serialization import load_lua doesn't work in current PyTorch versions because the torch.utils.serialization module has been removed. Here is a possible fix using torchfile:
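
    import torchfile

    class pytorch_lua_wrapper:
        """Minimal replacement for the object load_lua used to return: exposes get(idx)."""

        def __init__(self, lua_path):
            # Parse the Lua Torch .t7 file; tensors come back as NumPy arrays
            self.lua_model = torchfile.load(lua_path)

        def get(self, idx):
            # Return the fields of the idx-th module in the top-level container
            return self.lua_model._obj.modules[idx]._obj
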
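pytorch_lua_wrapper.get only chains attribute lookups on whatever torchfile.load returns. The sketch below mimics that shape with types.SimpleNamespace so the lookup can be checked without a .t7 file. make_stub_model and get_module are hypothetical helpers, and the stub is an assumed stand-in, not torchfile's real TorchObject class. It also illustrates an assumption worth checking against a real file: modules comes back as a 0-based Python list, so get(0) is the first layer even though Lua indices start at 1.

```python
from types import SimpleNamespace


def get_module(lua_model, idx):
    # Same lookup chain as pytorch_lua_wrapper.get
    return lua_model._obj.modules[idx]._obj


def make_stub_model(weights):
    # Hypothetical stand-in for torchfile.load(...) on an nn.Sequential:
    # a top-level object whose _obj.modules is a list of layer objects,
    # each exposing its fields (here only weight) through its own _obj.
    layers = [SimpleNamespace(_obj=SimpleNamespace(weight=w)) for w in weights]
    return SimpleNamespace(_obj=SimpleNamespace(modules=layers))


model = make_stub_model([[0.1, 0.2], [0.3]])
print(get_module(model, 0).weight)  # prints [0.1, 0.2] (Python index 0, Lua index 1)
```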
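Now you can replace this line:

    vgg1 = load_lua(args.vgg1)

with

    vgg1 = pytorch_lua_wrapper(args.vgg1)

and this line

    self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())

with

    self.conv1.weight = torch.nn.Parameter(torch.from_numpy(vgg1.get(0).weight).float())

torchfile returns the weights as NumPy arrays rather than tensors, so they have to go through torch.from_numpy before .float() can be called.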
Thanks, but I get the following error:

      File "G:\Project\A_NST\PytorchWCT-master\util.py", line 28, in __init__
        vgg1 = pytorch_lua_wrapper(args.vgg1)
      File "G:\Project\A_NST\PytorchWCT-master\util.py", line 18, in __init__
        self.lua_model = torchfile.load(lua_path)
      File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 424, in load
        return reader.read_obj()
      File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 370, in read_obj
        obj._obj = self.read_obj()
      File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 385, in read_obj
        k = self.read_obj()
      File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 386, in read_obj
        v = self.read_obj()
      File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 370, in read_obj
        obj._obj = self.read_obj()
      File "D:\anaconda3\envs\mypytorch\lib\site-packages\torchfile.py", line 387, in read_obj
        obj[k] = v
    TypeError: unhashable type: 'list'

A solution is described here: https://github.com/bshillingford/python-torchfile/issues/12
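For context on the traceback: its last frames show torchfile reading a Lua table entry as a key k and a value v, then storing obj[k] = v. The TypeError means the key it read was a Python list, and lists cannot be dict keys. A plausible cause (an assumption, not confirmed by this thread) is torchfile's list heuristic, which turns a Lua table with keys 1..n into a Python list, so a table that is itself used as a key becomes unhashable. The sketch below reproduces that failure in plain Python. apply_list_heuristic and read_table are hypothetical helpers, and the tuple conversion only illustrates one way to make such a key hashable. It is not necessarily the fix from the linked issue.

```python
def apply_list_heuristic(table):
    # Hypothetical stand-in for torchfile's list heuristic: a table whose keys
    # are exactly the integers 1..n is returned as a Python list.
    keys = list(table)
    if all(isinstance(k, int) for k in keys) and sorted(keys) == list(range(1, len(keys) + 1)):
        return [table[k] for k in sorted(keys)]
    return table


def read_table(pairs, make_hashable=False):
    # Mirrors the failing step in the traceback: obj[k] = v for each entry.
    obj = {}
    for k, v in pairs:
        if make_hashable and isinstance(k, list):
            k = tuple(k)  # illustration only: lists cannot be dict keys, tuples can
        obj[k] = v
    return obj


key = apply_list_heuristic({1: "a", 2: "b"})  # becomes ["a", "b"]
try:
    read_table([(key, "value")])
except TypeError as err:
    print(err)  # unhashable type: 'list'
print(read_table([(key, "value")], make_hashable=True))  # {('a', 'b'): 'value'}
```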