early-stopping-pytorch

check if model is instance of DataParallel before saving checkpoint

Open jiayangshi opened this issue 3 years ago • 0 comments

Problem

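In the current save_checkpoint function, if the model is on multiple GPUs, i.e. the model is an instance of torch.nn.DataParallel, the saved checkpoint cannot be loaded back into a plain (unwrapped) model, because DataParallel prefixes every key in the state dict with "module.".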

Describe your changes

Following the PyTorch tutorial, the model is first checked to see whether it is an instance of the DataParallel class. If it is, model.module.state_dict() is saved instead of model.state_dict(), which is what the current implementation saves.
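
A minimal sketch of the check described above, not the exact patch in this repository. The helper name `get_saveable_state_dict` and the `nn.Linear` toy model are assumptions made for illustration; only `torch.nn.DataParallel`, `model.module` and `state_dict()` come from the description. An in-memory buffer stands in for the checkpoint file.

```python
import io

import torch
import torch.nn as nn


def get_saveable_state_dict(model: nn.Module) -> dict:
    """Hypothetical helper: return the state dict without the DataParallel prefix."""
    if isinstance(model, nn.DataParallel):
        # DataParallel keeps the real model in .module; its keys have no prefix.
        return model.module.state_dict()
    return model.state_dict()


plain = nn.Linear(4, 2)            # toy model, stands in for the user's network
wrapped = nn.DataParallel(plain)   # can be constructed even without a GPU

# Keys as saved by the current implementation vs. with the check applied.
wrapped_keys = list(wrapped.state_dict().keys())
saved_keys = list(get_saveable_state_dict(wrapped).keys())

# Round trip: save the unwrapped state dict and load it into a fresh plain model.
buffer = io.BytesIO()
torch.save(get_saveable_state_dict(wrapped), buffer)
buffer.seek(0)
fresh = nn.Linear(4, 2)
fresh.load_state_dict(torch.load(buffer, map_location="cpu"))
```

Here `wrapped_keys` carries the "module." prefix, while `saved_keys` does not, so the checkpoint loads into a model that was never wrapped in DataParallel.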

jiayangshi, Mar 31 '22 10:03