TextZoom

Cannot load trained checkpoints

Open · hcl14 opened this issue 3 years ago · 4 comments

I used PyTorch 1.3 for both training (which worked fine) and inference. I had a problem with DataParallel, which I solved by copying the multi-GPU model-loading code into the single-GPU case:

if self.config.TRAIN.ngpu == 1:
    #model.load_state_dict(torch.load(self.resume)['state_dict_G'])
    model.load_state_dict(
        {'module.' + k: v for k, v in torch.load(self.resume)['state_dict_G'].items()})
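nn.DataParallel stores the wrapped network as its `module` attribute, so every key in a wrapped model's state_dict starts with `module.`. The snippet above always adds that prefix, which would double-prefix a checkpoint that was already saved from a wrapped model. Below is a minimal sketch of a more defensive variant. It assumes plain dict-like state dicts, and `match_module_prefix` is a hypothetical helper, not part of TextZoom.

```python
from collections import OrderedDict


def match_module_prefix(state_dict, model_keys, prefix="module."):
    """Return a copy of state_dict whose keys carry the DataParallel prefix
    exactly when the target model's keys do (hypothetical helper)."""
    model_keys = list(model_keys)
    model_wrapped = bool(model_keys) and all(k.startswith(prefix) for k in model_keys)
    fixed = OrderedDict()
    for key, value in state_dict.items():
        if model_wrapped and not key.startswith(prefix):
            key = prefix + key  # bare checkpoint -> wrapped model
        elif not model_wrapped and key.startswith(prefix):
            key = key[len(prefix):]  # wrapped checkpoint -> bare model
        fixed[key] = value
    return fixed


# With PyTorch it would be used like this (not executed here):
# ckpt = torch.load(self.resume, map_location="cpu")["state_dict_G"]
# model.load_state_dict(match_module_prefix(ckpt, model.state_dict().keys()))
```

Keys that already match are left untouched, so the same call works whether or not the checkpoint was saved from a DataParallel model.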

However, when I now try to run inference on RGB images, I get:

python3 main.py --demo --demo_dir='./images/' --resume='ckpt/vis/checkpoint.pth' --STN --mask

    Mission.demo()
  File "/home/ml/codes/TextZoom/src/interfaces/super_resolution.py", line 299, in demo
    images_sr = model(images_lr)
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/codes/TextZoom/src/model/tsrn.py", line 66, in forward
    block = {'1': self.block1(x)}
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Given groups=1, weight of size 64 4 9 9, expected input[1, 5, 32, 256] to have 4 channels, but got 5 channels instead
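The weight shape [64, 4, 9, 9] means the first convolution was trained on 4 input channels: 3 colour channels plus the binary mask built in transform_. The mask adds exactly one channel, so a 5-channel input suggests the loaded image itself had 4 channels, for example an RGBA PNG. That is an inference from the numbers, not something confirmed in this thread. The NumPy sketch below mirrors transform_ but keeps only the first three colour channels. It assumes an H x W x C uint8 array as input. `to_model_input` is hypothetical, and the luma weights only approximate PIL's convert('L').

```python
import numpy as np


def to_model_input(img, use_mask=True):
    """img: H x W x C uint8 array (C = 3 for RGB, 4 for RGBA).
    Returns a 1 x C' x H x W float32 array, with C' = 4 with mask and 3 without."""
    colour = img[..., :3].astype(np.float32)  # drop an alpha channel if present
    channels = [(colour / 255.0).transpose(2, 0, 1)]  # HWC -> CHW in [0, 1], like ToTensor
    if use_mask:
        # Grayscale via ITU-R 601-2 luma weights (the transform PIL documents for mode "L")
        gray = colour @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
        thres = gray.mean()
        # Same rule as transform_: bright pixels -> 0, dark pixels -> 1
        mask = np.where(gray > thres, 0.0, 1.0).astype(np.float32)
        channels.append(mask[None])
    return np.concatenate(channels, axis=0)[None]
```

In the PIL-based transform_, the corresponding change would be calling img.convert('RGB') right after Image.open(path), so that ToTensor always yields 3 channels before the mask is appended.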

When I remove the --mask argument, I get a weight-loading error instead:

Traceback (most recent call last):
  File "main.py", line 44, in <module>
    main(config, args)
  File "main.py", line 16, in main
    Mission.demo()
  File "/home/ml/codes/TextZoom/src/interfaces/super_resolution.py", line 276, in demo
    model_dict = self.generator_init()
  File "/home/ml/codes/TextZoom/src/interfaces/base.py", line 161, in generator_init
    {'module.' + k: v for k, v in torch.load(self.resume)['state_dict_G'].items()})
  File "/home/ml/anaconda3/envs/tf15_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DataParallel:
        size mismatch for module.block1.0.weight: copying a param with shape torch.Size([64, 4, 9, 9]) from checkpoint, the shape in current model is torch.Size([64, 3, 9, 9]).
        size mismatch for module.block8.1.weight: copying a param with shape torch.Size([4, 64, 9, 9]) from checkpoint, the shape in current model is torch.Size([3, 64, 9, 9]).
        size mismatch for module.block8.1.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([3]).
        size mismatch for module.stn_head.stn_convnet.0.0.weight: copying a param with shape torch.Size([32, 4, 3, 3]) from checkpoint, the shape in current model is torch.Size([32, 3, 3, 3]).
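These shapes describe the same mismatch from the other side. The checkpoint's block1.0.weight is [64, 4, 9, 9] and block8.1.weight is [4, 64, 9, 9], so the network was trained to take in and produce 4 channels, while the model built without --mask uses 3. The checkpoint therefore loads only into a model built with --mask, which matches the first error occurring in forward rather than in load_state_dict. The small sketch below spots such mismatches before load_state_dict is called. It assumes name-to-shape dicts; with PyTorch you could build them as {k: tuple(v.shape) for k, v in sd.items()}. Both helpers are hypothetical.

```python
def input_channels(shapes, key="module.block1.0.weight"):
    """Conv2d weights are laid out as (out_channels, in_channels, kH, kW),
    so dimension 1 is the number of input channels the layer expects."""
    return shapes[key][1]


def shape_mismatches(ckpt_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for shared keys whose shapes differ."""
    shared = sorted(ckpt_shapes.keys() & model_shapes.keys())
    return {
        k: (tuple(ckpt_shapes[k]), tuple(model_shapes[k]))
        for k in shared
        if tuple(ckpt_shapes[k]) != tuple(model_shapes[k])
    }
```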

By the way, I cannot use the checkpoint you provided at https://github.com/JasonBoy1/TextZoom/issues/25: loading it fails with a message saying that model.pth is a zip archive. Which version of PyTorch did you use?
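That message is consistent with a format change in PyTorch. Starting with PyTorch 1.6, torch.save writes a zip-based file by default, and older releases such as 1.3 reject it. Below is a stdlib-only sketch that checks which container a file uses, plus the re-save call that writes the legacy format. The re-save step needs PyTorch 1.6 or newer, and the file names are placeholders.

```python
import zipfile


def is_zip_checkpoint(path):
    """True if `path` uses the zip-based container that torch.save writes by
    default from PyTorch 1.6 on; PyTorch 1.3 cannot read that format."""
    return zipfile.is_zipfile(path)


# Conversion sketch, to be run under PyTorch >= 1.6 (not executed here):
# ckpt = torch.load("model.pth", map_location="cpu")
# torch.save(ckpt, "model_legacy.pth", _use_new_zipfile_serialization=False)
```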

hcl14 · Apr 01 '21 15:04

The code works if I comment out the mask part:

        def transform_(path):
            img = Image.open(path)
            img = img.resize((256, 32), Image.BICUBIC)
            img_tensor = transforms.ToTensor()(img)
            '''
            if mask_:
                mask = img.convert('L')
                thres = np.array(mask).mean()
                mask = mask.point(lambda x: 0 if x > thres else 255)
                mask = transforms.ToTensor()(mask)
                img_tensor = torch.cat((img_tensor, mask), 0)
            '''
            img_tensor = img_tensor.unsqueeze(0)
            return img_tensor

But I don't see the generated images saved anywhere. The resulting images_sr has shape torch.Size([1, 4, 64, 512]).

Edit:

OK, I saved everything. I needed to take the first 3 channels of the output.

from torchvision.utils import save_image

...

def transform_(path):
    img = Image.open(path)
    img = img.resize((256, 32), Image.BICUBIC)
    img_tensor = transforms.ToTensor()(img)
    '''
    if mask_:
        mask = img.convert('L')
        thres = np.array(mask).mean()
        mask = mask.point(lambda x: 0 if x > thres else 255)
        mask = transforms.ToTensor()(mask)
        img_tensor = torch.cat((img_tensor, mask), 0)
    '''
    img_tensor = img_tensor.unsqueeze(0)
    return img_tensor

images = []
for im_name in tqdm(os.listdir(self.args.demo_dir)):
    ...  # (existing demo code that builds images_lr from im_name, omitted)
    images_sr = model(images_lr)
    images.append(images_sr)

# Keep only the first 3 (colour) channels of the 4-channel output
images = torch.cat(images, 0)
save_image(images[:, :3, :, :], 'out.png')

I managed to work around the problem, but as you can see, the inference code seems broken.
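The output has 4 channels at twice the input resolution (32 x 256 in, 64 x 512 out). Only the first three are colour; the fourth presumably corresponds to the mask channel the model was trained with. save_image tiles a whole batch into a single grid image, so out.png holds every result at once. Below is a NumPy sketch of the per-image conversion. It assumes values in [0, 1], and `sr_batch_to_uint8` is hypothetical.

```python
import numpy as np


def sr_batch_to_uint8(batch):
    """batch: N x C x H x W floats in [0, 1]; channels 0-2 are RGB and any extra
    channel (here the 4th) is dropped. Returns a list of H x W x 3 uint8 arrays.
    The scale-and-round step roughly follows what torchvision's save_image does."""
    rgb = np.clip(batch[:, :3], 0.0, 1.0)
    out = (rgb * 255.0 + 0.5).astype(np.uint8)  # round to nearest integer
    return [img.transpose(1, 2, 0) for img in out]  # CHW -> HWC per image
```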

hcl14 · Apr 01 '21 15:04

I got it working by editing src/model/recognizer/attention_recognition_head.py, at lines 109 and 141.

Line 109 (see https://github.com/JasonBoy1/TextZoom/issues/16#issue-732964244), change

    state = state.index_select(1, predecessors.squeeze())

to

    state = state.index_select(1, predecessors.squeeze().long()) # `.long()` added via #16
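In the PyTorch versions discussed here, index_select expects an integer (LongTensor) index, and the added .long() guarantees that. In beam search, the parent of each surviving hypothesis is usually recovered from a flat top-k index by integer division. If that division yields floating-point values, the later index_select fails. The NumPy analogue below works under that assumption. It is not the TextZoom code, and `split_beam_indices` and `reorder_state` are hypothetical.

```python
import numpy as np


def split_beam_indices(flat_indices, vocab_size):
    """Map flat indices into a (beam * vocab) score vector back to
    (predecessor beam, emitted symbol). Floor division keeps the dtype integral."""
    flat = np.asarray(flat_indices, dtype=np.int64)
    return flat // vocab_size, flat % vocab_size


def reorder_state(state, predecessors):
    """Analogue of state.index_select(1, predecessors.long()): reorder the beam
    axis (axis 1) so each hypothesis picks up its parent's state."""
    return state[:, predecessors.astype(np.int64)]
```

NumPy refuses float index arrays with an IndexError, much as index_select refuses a float index tensor, which is why the cast happens before indexing.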

Line 141, change

    current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors)
    t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).squeeze()

to

    current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors.long())
    t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).long().squeeze()
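The two lines at 141 perform the backward pass of beam search: starting from the final beams, read off the symbol emitted at step t, then jump to the parent beam at step t-1. The sketch below shows that walk in NumPy. It assumes stored_predecessors[t] holds, for each beam at step t, the index of its parent beam, and `backtrack` is a hypothetical name. The cast to int64 plays the role of .long().

```python
import numpy as np


def backtrack(stored_symbols, stored_predecessors, final_beams):
    """Recover one symbol sequence per final beam by following predecessor
    pointers from the last step back to the first."""
    t_pred = np.asarray(final_beams, dtype=np.int64)
    steps = []
    for t in reversed(range(len(stored_symbols))):
        # Symbol emitted at step t by each tracked hypothesis
        steps.append(stored_symbols[t][t_pred])
        # Parent beam index for step t-1; cast so it can be used as an index
        t_pred = stored_predecessors[t][t_pred].astype(np.int64)
    return np.stack(steps[::-1], axis=1)  # shape: beams x steps
```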

I also changed src/interfaces/base.py (in the if self.config.TRAIN.ngpu == 1: block at around line 156):

                  #model.load_state_dict(torch.load(self.resume)['state_dict_G'])
                  state_dict = torch.load(self.resume)['state_dict_G']
                  from collections import OrderedDict
                  new_state_dict = OrderedDict()

                  for k, v in state_dict.items():
                      if 'module' not in k:
                          k = 'module.'+k
                      else:
                          k = k.replace('features.module.', 'module.features.')
                      new_state_dict[k]=v

                  model.load_state_dict(new_state_dict)

There were also several comparisons against literals using is and is not, which I changed to == and != after receiving warnings.
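Since Python 3.8, the compiler emits a SyntaxWarning when is or is not is used against a str, bytes or number literal. Whether two equal literals are the same object is an implementation detail, while == and != compare values. The stdlib sketch below reproduces the warning for a given snippet; `uses_identity_with_literal` is hypothetical.

```python
import warnings


def uses_identity_with_literal(source):
    """Return True if compiling `source` triggers Python's SyntaxWarning about
    using `is` / `is not` with a literal (emitted by CPython 3.8 and later)."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, even repeats
        compile(source, "<snippet>", "exec")
    return any(issubclass(w.category, SyntaxWarning) for w in caught)
```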

The training process got interrupted a few times (my fault!), but this turned out to be convenient because it let me keep copies of the checkpoints for later comparison: I simply moved the vis folder (in both demo/ and ckpt/) to vis_$iters, where $iters is the number of iterations (as shown by the names of the directories created in demo/vis/).
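Below is a stdlib sketch of that snapshot step. It assumes demo/vis and ckpt/vis sit under one root directory. `snapshot_vis` and its arguments are hypothetical, and the iteration count is passed in rather than parsed from the directory names.

```python
import shutil
from pathlib import Path


def snapshot_vis(root, iters, subdirs=("demo", "ckpt")):
    """Rename <root>/<sub>/vis to <root>/<sub>/vis_<iters> for each sub-directory
    that has one, and return the new paths (hypothetical helper)."""
    moved = []
    for sub in subdirs:
        src = Path(root) / sub / "vis"
        if not src.is_dir():
            continue
        dst = src.with_name(f"vis_{iters}")
        if dst.exists():
            # shutil.move would otherwise nest src inside the existing dst
            raise FileExistsError(dst)
        shutil.move(str(src), str(dst))
        moved.append(dst)
    return moved
```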

In addition to your suggestion for the demo, I enumerated the images and saved them to separately numbered files:

from pathlib import Path
from torchvision.utils import save_image

images = torch.cat(images, 0) # via #30
n_images = len(images)
zf_len = len(str(n_images))
out_dir = Path("results")
out_dir.mkdir(exist_ok=True)
for im_count, out_img in enumerate(images[:, :3, :, :]):
    save_image(out_img, out_dir / f"out_{str(im_count).zfill(zf_len)}.png") # via #30
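The zero padding keeps the files in order when they are listed alphabetically; without it, out_10.png would sort before out_2.png. Below is a tiny sketch of just the naming logic; `numbered_names` is hypothetical.

```python
def numbered_names(n_images, stem="out", ext=".png"):
    """File names zero-padded to the digit count of n_images, as in the loop above,
    so that a plain lexicographic sort matches numeric order."""
    width = len(str(n_images))
    return [f"{stem}_{str(i).zfill(width)}{ext}" for i in range(n_images)]
```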

lmmx · Apr 08 '21 16:04

Hey! Can you help me solve this error? [screenshot]

kanika02 · Oct 26 '21 09:10

I have made the changes given in the code above, but I am still not able to resume from the checkpoint.

kanika02 · Oct 26 '21 09:10