TextZoom
Cannot load trained checkpoints
PyTorch 1.3 was used for training (OK) and for inference. I had a problem with DataParallel, which I solved by copying the multi-GPU model-loading code into the single-GPU case:
if self.config.TRAIN.ngpu == 1:
    # model.load_state_dict(torch.load(self.resume)['state_dict_G'])
    model.load_state_dict(
        {'module.' + k: v for k, v in torch.load(self.resume)['state_dict_G'].items()})
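For context: checkpoints saved from inside a torch.nn.DataParallel wrapper carry a module. prefix on every key, while this checkpoint's keys are unprefixed, so the wrapped model refuses to load them as-is. A minimal sketch of the two ways to reconcile this, assuming the checkpoint layout above:

import torch

state_dict = torch.load(self.resume)['state_dict_G']

# Option 1: add the prefix so the keys match the DataParallel-wrapped model
# (this is what the workaround above does).
model.load_state_dict({'module.' + k: v for k, v in state_dict.items()})

# Option 2: bypass the wrapper and load the unprefixed keys into the underlying module.
model.module.load_state_dict(state_dict)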
However, now when I try to run inference on RGB images I get:
python3 main.py --demo --demo_dir='./images/' --resume='ckpt/vis/checkpoint.pth' --STN --mask
Mission.demo()
File "/home/ml/codes/TextZoom/src/interfaces/super_resolution.py", line 299, in demo
images_sr = model(images_lr)
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
return self.module(*inputs[0], **kwargs[0])
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/ml/codes/TextZoom/src/model/tsrn.py", line 66, in forward
block = {'1': self.block1(x)}
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
input = module(input)
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
result = self.forward(*input, **kwargs)
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 345, in forward
return self.conv2d_forward(input, self.weight)
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 4 9 9, expected input[1, 5, 32, 256] to have 4 channels, but got 5 channels instead
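If I read this correctly, the checkpoint's first conv expects 4 input channels (3 RGB + 1 mask), but the demo fed it 5, which is what happens if Image.open returns an RGBA image (4 channels) and the mask is then concatenated on top. A minimal sketch of a fix in transform_, assuming that is the cause:

# Force a 3-channel image before the mask channel is appended:
img = Image.open(path).convert('RGB')  # drops a possible alpha channel (RGBA -> RGB)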
When I remove the --mask argument, I get a weight-loading error instead:
Traceback (most recent call last):
File "main.py", line 44, in <module>
main(config, args)
File "main.py", line 16, in main
Mission.demo()
File "/home/ml/codes/TextZoom/src/interfaces/super_resolution.py", line 276, in demo
model_dict = self.generator_init()
File "/home/ml/codes/TextZoom/src/interfaces/base.py", line 161, in generator_init
{'module.' + k: v for k, v in torch.load(self.resume)['state_dict_G'].items()})
File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
size mismatch for module.block1.0.weight: copying a param with shape torch.Size([64, 4, 9, 9]) from checkpoint, the shape in current model is torch.Size([64, 3, 9, 9]).
size mismatch for module.block8.1.weight: copying a param with shape torch.Size([4, 64, 9, 9]) from checkpoint, the shape in current model is torch.Size([3, 64, 9, 9]).
size mismatch for module.block8.1.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for module.stn_head.stn_convnet.0.0.weight: copying a param with shape torch.Size([32, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 3, 3, 3]).
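So the checkpoint was apparently trained with the mask channel (4 input/output channels), which means the demo model has to be built with --mask to match. A quick way to confirm this from the checkpoint itself (key name taken from the error messages above, assuming unprefixed keys as the loading code implies):

import torch

# Inspect the first conv's weight shape stored in the checkpoint.
sd = torch.load('ckpt/vis/checkpoint.pth', map_location='cpu')['state_dict_G']
print(sd['block1.0.weight'].shape)  # torch.Size([64, 4, 9, 9]) -> 4 input channels, i.e. trained with --mask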
By the way, I cannot use the checkpoint you provided at https://github.com/JasonBoy1/TextZoom/issues/25; it says that model.pth is a zip archive. Which version of PyTorch did you use?
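My guess is that the zip-archive complaint means the checkpoint was saved with PyTorch >= 1.6, which switched torch.save to a zip-based format that PyTorch 1.3 cannot read. If so, re-saving it once in a newer environment should make it loadable; a sketch:

import torch

# Run once under PyTorch >= 1.6 to convert the checkpoint to the legacy format.
ckpt = torch.load('model.pth', map_location='cpu')
torch.save(ckpt, 'model_legacy.pth', _use_new_zipfile_serialization=False)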
The code works if I comment out the mask part:
def transform_(path):
    img = Image.open(path)
    img = img.resize((256, 32), Image.BICUBIC)
    img_tensor = transforms.ToTensor()(img)
    '''
    if mask_:
        mask = img.convert('L')
        thres = np.array(mask).mean()
        mask = mask.point(lambda x: 0 if x > thres else 255)
        mask = transforms.ToTensor()(mask)
        img_tensor = torch.cat((img_tensor, mask), 0)
    '''
    img_tensor = img_tensor.unsqueeze(0)
    return img_tensor
But I don't see the generated images saved anywhere... The resulting images_sr has shape torch.Size([1, 4, 64, 512]): 4 output channels at twice the 32×256 input resolution.
Edit: OK, I got everything saved. I needed to take the first 3 channels (RGB) of the output:
from torchvision.utils import save_image
...
def transform_(path):
    img = Image.open(path)
    img = img.resize((256, 32), Image.BICUBIC)
    img_tensor = transforms.ToTensor()(img)
    '''
    if mask_:
        mask = img.convert('L')
        thres = np.array(mask).mean()
        mask = mask.point(lambda x: 0 if x > thres else 255)
        mask = transforms.ToTensor()(mask)
        img_tensor = torch.cat((img_tensor, mask), 0)
    '''
    img_tensor = img_tensor.unsqueeze(0)
    return img_tensor

images = []
for im_name in tqdm(os.listdir(self.args.demo_dir)):
    images_lr = transform_(os.path.join(self.args.demo_dir, im_name))  # build the LR input tensor for this image
    images_sr = model(images_lr)
    images.append(images_sr)
images = torch.cat(images, 0)
save_image(images[:, :3, :, :], 'out.png')
I managed to solve the problem, but as you can see, the inference code seems broken.
I got it working by editing src/model/recognizer/attention_recognition_head.py, at lines 109 and 141.

Line 109 (see https://github.com/JasonBoy1/TextZoom/issues/16#issue-732964244):
state = state.index_select(1, predecessors.squeeze())
⇣
state = state.index_select(1, predecessors.squeeze().long()) # `.long()` added via #16
Line 141:
current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors)
t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).squeeze()
⇣
current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors.long())
t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).long().squeeze()
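For what it's worth, the reason these casts are needed is that index_select requires an integer (Long) index tensor, and newer PyTorch versions raise an error instead of silently casting a float index. A minimal standalone illustration:

import torch

x = torch.arange(12.).reshape(3, 4)
idx = torch.tensor([2., 0.])           # float indices, like the beam-search bookkeeping produces here
# x.index_select(0, idx)               # raises a RuntimeError: index_select needs an integer index
print(x.index_select(0, idx.long()))   # casting with .long() makes the call valid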
I also changed src/interfaces/base.py (in the if self.config.TRAIN.ngpu == 1: block, at around line 156):
# model.load_state_dict(torch.load(self.resume)['state_dict_G'])
state_dict = torch.load(self.resume)['state_dict_G']
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    if 'module' not in k:
        k = 'module.' + k
    else:
        k = k.replace('features.module.', 'module.features.')
    new_state_dict[k] = v
model.load_state_dict(new_state_dict)
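To make the remapping concrete, here is what it does to two hypothetical checkpoint keys:

# 'block1.0.weight'          -> 'module.block1.0.weight'    (prefix added for the DataParallel wrapper)
# 'features.module.0.weight' -> 'module.features.0.weight'  (misplaced prefix moved to the front)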
There were also multiple comparisons against literals using is and is not, which I changed to == and != after receiving warnings.
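A hypothetical illustration of the pattern that triggers the warning (identity comparison where a value comparison is intended):

mode = 'bicubic'
if mode is 'bicubic':    # SyntaxWarning: "is" with a literal. Did you mean "=="?
    pass
if mode == 'bicubic':    # value comparison, the intended check
    pass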
The training process got interrupted a few times (my fault!), but this turned out to be convenient, as it left me checkpoints to compare later anyway: I just moved the vis folder (in both demo/ and ckpt/) to vis_$iters, where $iters is the number of iterations (as shown by the names of the directories created in demo/vis/).
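A sketch of that renaming step in Python (the iteration count is hypothetical; take it from the directory names in demo/vis/):

import shutil

iters = 20000  # hypothetical iteration count
shutil.move('demo/vis', f'demo/vis_{iters}')
shutil.move('ckpt/vis', f'ckpt/vis_{iters}')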
In addition to your suggestion for the demo, I enumerated the images and saved them to separately numbered files:
from pathlib import Path

images = torch.cat(images, 0)  # via #30
n_images = len(images)
zf_len = len(str(n_images))
out_dir = Path("results")
out_dir.mkdir(exist_ok=True)
for im_count, out_img in enumerate(images[:, :3, :, :]):
    save_image(out_img, out_dir / f"out_{str(im_count).zfill(zf_len)}.png")  # via #30
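With, say, 12 input images, zf_len comes out as 2 and the files are written as results/out_00.png through results/out_11.png, so they sort correctly in a directory listing.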
Hey! Can you help me solve this error? I have made the changes given in the code above, but I am still not able to resume from the checkpoint.