
Saving and loading checkpoint model for resuming training in R (using Torch)

Open dominikj2 opened this issue 4 years ago • 3 comments

Hi, I’m not able to save and load a checkpoint model for resuming training in R. I have only been able to work out an R procedure that saves the entire model, not specific state_dicts. When resuming training, I start off with the correct loss value at a given epoch, but each epoch thereafter results in identical loss values (i.e. training does not update the optimised parameters and the loss does not change).

In PyTorch, it is encouraged to save the model’s state_dict() as well as optimizer.state_dict() and scheduler.state_dict(). PyTorch does not recommend this:

torch.save(model, 'save/to/path/model.pt')
model = torch.load('load/from/path/model.pt')

But recommends this:

torch.save({
            'epoch': EPOCH,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'loss': LOSS,
            }, 'save/to/path/model.pth')

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# the scheduler also has to be re-created the same way before its state is loaded
checkpoint = torch.load('load/from/path/model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

In R, my understanding is that it is not possible to save multiple components and organize them in a dictionary (or list) using torch_save() to serialize them.

I am not sure what the alternative approach may be. I am not able to find an R version of optimizer.load_state_dict or scheduler.load_state_dict. Does torch_save() save all the state_dicts, and if so, how do I access and use the correct state_dict to resume training? For example, the optimiser objects optimizer <- optim_sgd(model$parameters, …) and optimizer <- optim_sgd(model$state_dict(), …) only have optimizer$state, which is an empty list(); there is no optimizer$state_dict.

Also, I’m using the lr_one_cycle scheduler to schedule the learning rate. After loading the model I use the code below:

optimizer <- optim_sgd(model$parameters, lr = OPTIMAL_LR, momentum = Para_momentum)  # model_Instance_Para
scheduler <- optimizer %>%
  lr_one_cycle(max_lr = OPTIMAL_LR, epochs = Para_num_epochs, steps_per_epoch = Train_dl$.length())

The scheduler has a scheduler$state_dict() method, but I don’t know how to use it.

Below is my pseudo-code:

- model <- torch_load(checkpoint_fpath)
- Model_State_Dict <- model$parameters  # I also tried Model_State_Dict <- model$state_dict()
- optimizer <- optim_sgd(Model_State_Dict)
- Define criterion
- EPOCH LOOP
    - model$train()
    - for b in enumerate(Train_dl)
        - optimizer$zero_grad()
        - compute multitaskloss
        - multitaskloss$backward()
        - optimizer$step()
        - scheduler$step()  # NOT SURE WHERE STATE_DICT IS FOR THIS
    - torch_save(model, checkpoint_fpath)  # THIS IS WHAT I LOAD FOR RESUMING TRAINING
    - model$eval()
    - for bb in enumerate(Valid_dl)
        - compute multitaskloss

Am I misunderstanding something? Any help would be appreciated. Thank you.

dominikj2 avatar Jun 09 '21 00:06 dominikj2

Hi @dominikj2, your understanding is correct, but you are hitting a few bugs and missing features. A few of them have already been fixed in the dev version and will come to CRAN soon.

I start off with the correct loss value at a given epoch, but each epoch thereafter results in identical loss values (i.e. training does not update the optimised parameters and the loss does not change).

This is indeed a bug (probably the same as #559) and was fixed in #585. #585 also adds load_state_dict() and state_dict() to optimizers.

In R, my understanding is that it is not possible to save multiple components and organize them in a dictionary (or list) using torch_save() to serialize them.

This was also implemented yesterday in #586.

So AFAICT the current dev version of torch should fix all these issues. You can install it with remotes::install_github("mlverse/torch").
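
With those changes, a checkpointing workflow would look roughly like the sketch below. This is untested and only a sketch: the tiny net module is a stand-in for your own model definition, the training loop itself is omitted, and the scheduler would be re-created alongside the optimizer in the same way.

library(torch)

# tiny stand-in module; substitute your own model definition
net <- nn_module(
  initialize = function() {
    self$fc <- nn_linear(10, 1)
  },
  forward = function(x) self$fc(x)
)

model <- net()
optimizer <- optim_sgd(model$parameters, lr = 0.01, momentum = 0.9)

# ... train for one or more epochs ...

# save a checkpoint: a plain R list serialized with torch_save()
checkpoint <- list(
  epoch                = 5,
  model_state_dict     = model$state_dict(),
  optimizer_state_dict = optimizer$state_dict()
)
torch_save(checkpoint, "checkpoint.pt")

# resume: re-create the model and optimizer the same way, then restore their state
checkpoint <- torch_load("checkpoint.pt")

model <- net()
model$load_state_dict(checkpoint$model_state_dict)

optimizer <- optim_sgd(model$parameters, lr = 0.01, momentum = 0.9)
optimizer$load_state_dict(checkpoint$optimizer_state_dict)

start_epoch <- checkpoint$epoch + 1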

dfalbel avatar Jun 09 '21 10:06 dfalbel

Hi @dfalbel,

Thank you for adding load_state_dict() and state_dict() for loading saved model parameters and optimiser parameters.

I am now only unclear on how the scheduler parameters (i.e. lr_one_cycle) are saved and loaded.

scheduler <- optimizer %>%
  lr_one_cycle(max_lr = OPTIMAL_LR, epochs = Para_num_epochs, steps_per_epoch = Train_dl$.length())

Calling scheduler$state_dict() on the above gives the following error:

Error in dict[[-which(names(dict) == "optimizer")]] : 
  attempt to select more than one element in integerOneIndex

scheduler$state_dict is the following function:

function () 
{
    dict <- as.list(self)
    dict <- dict[[-which(names(dict) == "optimizer")]]
    dict
}
<environment: 0x00000243738a9ae8>
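
I suspect the problem is the [[ indexing with a negative index: double brackets select exactly one element, so a negative subscript errors, while single brackets drop elements. Something like the lines below does what the function presumably intends, although I don’t know if that is the whole fix:

dict <- as.list(self)
dict <- dict[-which(names(dict) == "optimizer")]  # `[` with a negative index drops the element; `[[` errors
dict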

The following link suggests that in Python, loading schedulers should be very similar to loading models and optimisers: https://stackoverflow.com/questions/67119827/how-to-load-a-learning-rate-scheduler-state-dict

Any suggestions?

dominikj2 avatar Jun 11 '21 04:06 dominikj2

This looks like a bug in the state_dict() implementation. Sorry about that, we will fix it as soon as possible.

dfalbel avatar Jun 11 '21 19:06 dfalbel