deeprank
Loading pretrained model on CPU using `development` branch
When a model that was trained on GPU is loaded on CPU, lines 230-240 of NeuralNet.py remove the keys from the state_dict, so the model can't be loaded. For any parameter name that `lstrip('module.')` leaves unchanged, the entry is reassigned to itself and then deleted.
# load parameters of pretrained model if provided
if self.pretrained_model:
    # a prefix 'module.' is added to parameter names if
    # torch.nn.DataParallel was used
    # https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    if self.state['cuda']:
        for paramname in list(self.state['state_dict'].keys()):
            paramname_new = paramname.lstrip('module.')
            self.state['state_dict'][paramname_new] = \
                self.state['state_dict'][paramname]
            del self.state['state_dict'][paramname]
    self.load_model_params()
I'm not even sure whether handling the `module.` prefix is still necessary.
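One possible fix, as a minimal sketch rather than the actual patch: rename a key only when it really starts with `module.`, and build a new dict instead of deleting entries in place. The helper name `strip_module_prefix` and the example keys are hypothetical and not part of deeprank. Note also that `str.lstrip('module.')` strips any leading run of the characters `m`, `o`, `d`, `u`, `l`, `e`, `.` rather than the literal prefix, so it can mangle names as well.

```python
from collections import OrderedDict

PREFIX = 'module.'  # prefix that torch.nn.DataParallel adds to parameter names


def strip_module_prefix(state_dict, prefix=PREFIX):
    """Return a copy of state_dict with a leading prefix removed from its keys.

    Hypothetical helper: keys without the prefix are copied unchanged,
    so nothing is ever deleted.
    """
    cleaned = OrderedDict()
    for name, value in state_dict.items():
        if name.startswith(prefix):
            name = name[len(prefix):]
        cleaned[name] = value
    return cleaned


# Illustrative keys only; the integers stand in for parameter tensors.
wrapped = OrderedDict([('module.mlp.weight', 1), ('module.conv.bias', 2)])
plain = OrderedDict([('conv.weight', 3), ('layer.bias', 4)])

print(list(strip_module_prefix(wrapped)))  # ['mlp.weight', 'conv.bias']
print(list(strip_module_prefix(plain)))    # ['conv.weight', 'layer.bias']

# For comparison, lstrip treats its argument as a set of characters:
print('module.mlp.weight'.lstrip('module.'))  # 'p.weight'
```

In `NeuralNet.py` this would presumably amount to something like `self.state['state_dict'] = strip_module_prefix(self.state['state_dict'])` before `self.load_model_params()`, assuming the surrounding code stays as quoted above.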