Add a method to initialize dynamic modules from state_dict
Hi,
Training
I initialize net=MTSimpleCNN() and train this multi-task model, which has a multi-head classifier, on two tasks (one task per experience). After training, the model has two heads.
I save the model's state dict using standard PyTorch: torch.save({'state_dict': net.state_dict()}, filename). I want to evaluate it later from a different script.
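Roughly, the training and saving side looks like this (the benchmark and strategy below are only illustrative, not my exact script; import paths follow recent avalanche releases):

import torch
from avalanche.benchmarks.classic import SplitCIFAR10
from avalanche.models import MTSimpleCNN
from avalanche.training.supervised import Naive

# one task per experience, so the multi-head classifier grows one head per task
benchmark = SplitCIFAR10(n_experiences=2, return_task_id=True)
net = MTSimpleCNN()
strategy = Naive(net, torch.optim.SGD(net.parameters(), lr=0.01),
                 torch.nn.CrossEntropyLoss(), train_mb_size=32, train_epochs=1)

for experience in benchmark.train_stream:
    strategy.train(experience)

# after training on two tasks, the state dict contains two heads
torch.save({'state_dict': net.state_dict()}, 'mt_simple_cnn.pt')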
Evaluation
I initialize a new model net=MTSimpleCNN() and load the weights using standard PyTorch: net.load_state_dict(torch.load(filename, map_location=device)['state_dict']). However, the saved state dict has two heads while the newly initialized model has only one. The keys don't match, so I get this error:
Error(s) in loading state_dict for MTSimpleCNN: Unexpected key(s) in state_dict: "classifier.active_units_T0", "classifier.active_units_T1", "classifier.classifiers.1.active_units", "classifier.classifiers.1.classifier.weight", "classifier.classifiers.1.classifier.bias", "classifier.classifiers.0.active_units"
Is there an easy way to initialize this multi-head network with an arbitrary number of heads?
Cheers, Woj
You have to adapt the architecture before using load_state_dict. Here is how I do it:
import torch
from avalanche.models import as_multitask
from avalanche.models.utils import avalanche_model_adaptation  # import paths may vary by avalanche version

def load_mt_model(model_fname, stream):
    # rebuild the same backbone used at training time (ResNet18 is our own constructor)
    model = ResNet18(100)
    model = as_multitask(model, 'linear')
    # grow one head per experience so the key sets match the checkpoint
    for e in list(stream):
        avalanche_model_adaptation(model, e)
    model.load_state_dict(torch.load(model_fname, map_location=torch.device('cpu')))
    return model.eval()
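For example, on the loading side (the benchmark here is an assumption; it just has to expose the same task labels as at training time):

from avalanche.benchmarks.classic import SplitCIFAR100

benchmark = SplitCIFAR100(n_experiences=2, return_task_id=True)  # same split as training
model = load_mt_model('checkpoint.pt', benchmark.test_stream)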
I'm leaving this issue open because we should provide this functionality somewhere.
@AntonioCarta I might work on that next. I think we should indeed implement a load_state_dict method for DynamicModules that automatically adapts the module. That would simplify a lot of things and make it easier to use DynamicModules out of the box, without needing to reload the whole scenario.
Maybe we have to wait for this issue to be solved: https://github.com/pytorch/pytorch/issues/75287. These load_state_dict pre-hooks seem like the way to go, but they are implemented as private functions right now. Since the public load_state_dict method is not called recursively on submodules, we cannot simply override load_state_dict in DynamicModule as I initially thought. A solution for now could be to use the private API and pray that it stays consistent across the torch versions we support, then switch to the public API if it is ever exposed.
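As a rough illustration of the pre-hook mechanism (private API, so subject to change; the toy module below is a self-contained stand-in, not the avalanche MultiHeadClassifier):

import torch
import torch.nn as nn

class ToyMultiHead(nn.Module):
    """Toy dynamic multi-head module that grows heads from the incoming checkpoint."""
    def __init__(self, in_features=32, out_features=10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.heads = nn.ModuleDict({'0': nn.Linear(in_features, out_features)})
        # private API: runs before load_state_dict copies tensors into this module
        self._register_load_state_dict_pre_hook(self._grow_missing_heads)

    def _grow_missing_heads(self, state_dict, prefix, local_metadata, strict,
                            missing_keys, unexpected_keys, error_msgs):
        # create any head present in the checkpoint but absent here,
        # so the key sets match and loading succeeds
        for key in state_dict:
            if key.startswith(prefix + 'heads.'):
                task = key[len(prefix + 'heads.'):].split('.')[0]
                if task not in self.heads:
                    self.heads[task] = nn.Linear(self.in_features, self.out_features)

# a two-head checkpoint loads into a freshly built one-head model
trained = ToyMultiHead()
trained.heads['1'] = nn.Linear(32, 10)
fresh = ToyMultiHead()
fresh.load_state_dict(trained.state_dict())  # no "unexpected key" error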
@AlbinSou can we provide an external function, like avalanche_model_adaptation? Then, once the load_state_dict pre-hooks become public in PyTorch, we can integrate that solution more tightly with the nn.Module methods.
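Something along these lines, perhaps (the function name is only a sketch, not an existing API, and the import path may differ across avalanche versions):

import torch
from avalanche.models.utils import avalanche_model_adaptation

def load_dynamic_state_dict(model, checkpoint_path, stream, device='cpu'):
    # grow all DynamicModules in `model` so they match the tasks seen in `stream`
    for experience in stream:
        avalanche_model_adaptation(model, experience)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    return model.eval()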