SemTorch
Can't load HRNet Segmentation weights from PTH file
Using get_segmentation_learner(architecture_name='hrnet', backbone_name='hrnet_w18').
The following callback saves the model during training:
SaveModelCallback(monitor='dice_multi', fname='best_model', with_opt=True)
After loading "best_model.pth" with learner.load, the predictions are zero-filled masks.
Predictions made with the learner right after training are correct.
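For reference, here is a sketch of where that checkpoint should end up, assuming fastai v2 defaults. SaveModelCallback saves through Learner.save, which writes to learn.path / learn.model_dir / f"{fname}.pth" (model_dir defaults to "models"). Learner.load adds the .pth extension itself, so the matching call would be learn.load("best_model", with_opt=True). The helper name below is hypothetical and only reproduces that path logic.

```python
from pathlib import Path


def expected_checkpoint_path(learner_path, fname="best_model", model_dir="models"):
    """Return the file that Learner.save(fname) would write.

    Assumes fastai v2 defaults: the file goes under learn.path / learn.model_dir,
    and ".pth" is appended automatically, so `fname` should not include it.
    """
    return Path(learner_path) / model_dir / f"{fname}.pth"


print(expected_checkpoint_path("/content", "best_model"))
```

Checking that this file exists and was written during the run you expect helps rule out loading a stale or wrong checkpoint.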
Hi, sorry for the very late reply. Were you using Windows?
No, we were using Google Colab notebooks (Ubuntu 18.04.3 LTS 64-bit).
w32 and w48 don't work either.
PyTorch 1.12, Python 3.9 on Paperspace
learn = get_segmentation_learner(dls=dls, number_classes=2, segmentation_type="Semantic Segmentation",
architecture_name="hrnet", backbone_name="hrnet_w32",
splitter=segmentron_splitter,
loss_func=CustomLoss(),
metrics=[Dice, foreground_acc, JaccardCoeff],
wd=1e-3).to_fp16()
RuntimeError Traceback (most recent call last)
File /usr/local/lib/python3.9/dist-packages/semtorch/models/archs/backbones/build.py:51, in load_backbone_pretrained(model, backbone)
49 weights_path = download(model_urls[backbone], path=weights_path)
---> 51 msg = model.init_weights(pretrained=weights_path)
52 else:
File /usr/local/lib/python3.9/dist-packages/semtorch/models/archs/backbones/hrnet.py:475, in HighResolutionNet.init_weights(self, pretrained)
474 model_dict.update(pretrained_dict)
--> 475 self.load_state_dict(model_dict)
476 return "HRNet backbone wieghts loaded"
File /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1604, in Module.load_state_dict(self, state_dict, strict)
1603 if len(error_msgs) > 0:
-> 1604 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1605 self.__class__.__name__, "\n\t".join(error_msgs)))
1606 return _IncompatibleKeys(missing_keys, unexpected_keys)
Complete error message: hrnet-w32-error.txt
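To see what load_state_dict is rejecting, one option is to compare parameter names and shapes between the model and the cached weights file. The function below is hypothetical (it is not part of SemTorch or PyTorch) and works on plain {name: shape} dicts. In practice those could be built with {k: tuple(v.shape) for k, v in model.state_dict().items()} and the same comprehension over torch.load(path, map_location="cpu"), assuming the file holds a plain state dict. Shape mismatches would suggest the cached file holds weights for a different HRNet width.

```python
def diff_state_dict_shapes(model_shapes, checkpoint_shapes):
    """Compare {parameter name: shape} mappings from a model and a checkpoint.

    Returns (missing, unexpected, mismatched):
      missing    - names the model has but the checkpoint lacks
      unexpected - names the checkpoint has but the model lacks
      mismatched - (name, model_shape, checkpoint_shape) for shared names whose shapes differ
    """
    model_names = set(model_shapes)
    ckpt_names = set(checkpoint_shapes)
    missing = sorted(model_names - ckpt_names)
    unexpected = sorted(ckpt_names - model_names)
    mismatched = sorted(
        (name, tuple(model_shapes[name]), tuple(checkpoint_shapes[name]))
        for name in model_names & ckpt_names
        if tuple(model_shapes[name]) != tuple(checkpoint_shapes[name])
    )
    return missing, unexpected, mismatched


# Hypothetical parameter names and shapes, for illustration only.
model = {"conv1.weight": (64, 3, 3, 3), "branch.conv.weight": (32, 32, 3, 3)}
ckpt = {"conv1.weight": (64, 3, 3, 3), "branch.conv.weight": (18, 18, 3, 3)}
print(diff_state_dict_shapes(model, ckpt))
```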
After deleting the previously cached file, the notebook loads the hrnet_w32 weights. It seems the .pth cache is always stored in ~/.cache/torch/checkpoints.
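A minimal sketch of that workaround using only the standard library. The default directory is the one reported above. The glob pattern "hrnet*.pth" and both function names are hypothetical guesses, so print the matches and check them before deleting anything.

```python
from pathlib import Path


def find_cached_weights(pattern="hrnet*.pth", cache_dir=None):
    """List cached weight files in `cache_dir` whose names match `pattern`.

    Defaults to ~/.cache/torch/checkpoints (the location observed in this thread).
    The pattern is an assumption about the file names, not SemTorch's naming scheme.
    """
    if cache_dir is None:
        cache_dir = Path.home() / ".cache" / "torch" / "checkpoints"
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return []
    return sorted(cache_dir.glob(pattern))


def delete_cached_weights(paths):
    """Delete the given files so the next run downloads fresh weights."""
    for path in paths:
        path.unlink()


stale = find_cached_weights()
print(stale)  # inspect this list first, then call delete_cached_weights(stale)
```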