early-stopping-pytorch
Check whether the model is an instance of DataParallel before saving a checkpoint
Problem
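In the current save_checkpoint function, if the model runs on multiple GPUs (i.e. it is an instance of torch.nn.DataParallel), the saved checkpoint cannot be loaded back into the unwrapped model. DataParallel stores the original model in its module attribute, so every key in model.state_dict() gets a "module." prefix, and load_state_dict() on a plain model fails with missing and unexpected keys.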
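A minimal sketch reproducing the problem, assuming a small nn.Linear stands in for the real model (on a CPU-only machine, nn.DataParallel simply holds the wrapped module, so this runs without GPUs):

```python
import torch.nn as nn

# Wrap a plain model the same way a multi-GPU training script would.
net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

# The wrapper's state_dict prefixes every key with "module.".
wrapped_keys = list(wrapped.state_dict().keys())
print(wrapped_keys)  # ['module.weight', 'module.bias']

# Loading those keys into an unwrapped model fails under strict loading.
fresh = nn.Linear(4, 2)
load_error = None
try:
    fresh.load_state_dict(wrapped.state_dict())
except RuntimeError as err:
    load_error = err
    print("load failed:", type(err).__name__)
```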
Describe your changes
Following the PyTorch tutorial, save_checkpoint now first checks whether the model is an instance of the DataParallel class.
If it is, model.module.state_dict() is saved instead of model.state_dict(), which the current implementation always saves.
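The change can be sketched as follows. This is an illustrative standalone version, not the repository's actual EarlyStopping.save_checkpoint method; the helpers unwrap_model and save_checkpoint(model, path) below are hypothetical names chosen for the example:

```python
import os
import tempfile

import torch
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the inner model if wrapped in DataParallel."""
    # DataParallel keeps the original model in `.module`.
    if isinstance(model, nn.DataParallel):
        return model.module
    return model


def save_checkpoint(model: nn.Module, path: str) -> None:
    """Hypothetical helper: save a state_dict without the "module." prefix."""
    torch.save(unwrap_model(model).state_dict(), path)


# Usage: save from a DataParallel-wrapped model, then load into a plain one.
net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
save_checkpoint(wrapped, path)

fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load(path))  # keys are "weight" and "bias"
```

Unwrapping at save time keeps the checkpoint format identical for single- and multi-GPU runs, so loading code does not need to strip prefixes.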