pytorch-soft-actor-critic

Resume training

Tomeu7 opened this issue 3 years ago · 5 comments

Hello, I am trying to use the SAC agent and resume training. To do that I do:

def load_model(self, actor_path, critic_path, optimizer_actor_path, optimizer_critic_path, optimizer_alpha_path):

  # The actor checkpoint also stores alpha and log_alpha.
  policy = torch.load(actor_path)
  self.alpha = policy['alpha'].detach().item()
  self.log_alpha = torch.tensor([policy['log_alpha'].detach().item()], requires_grad=True, device=self.device)
  self.alpha_optim = Adam([self.log_alpha], lr=self.lr)  # I had to recreate the alpha optimizer with the newly loaded log_alpha

  self.policy.load_state_dict(policy['model_state_dict'])
  self.policy.train()
  self.critic.load_state_dict(torch.load(critic_path))
  self.critic.train()

  self.policy_optim.load_state_dict(torch.load(optimizer_actor_path))
  self.critic_optim.load_state_dict(torch.load(optimizer_critic_path))
  self.alpha_optim.load_state_dict(torch.load(optimizer_alpha_path))

Is this correct? The loss explodes after resuming, which is very strange.
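
For reference, the save side looks roughly like this (a sketch mirroring the load code above; the exact checkpoint contents are an assumption, and the paths are placeholders):

def save_model(self, actor_path, critic_path, optimizer_actor_path, optimizer_critic_path, optimizer_alpha_path):

  # Bundle alpha / log_alpha with the actor weights so load_model can find them under these keys.
  torch.save({'model_state_dict': self.policy.state_dict(),
              'alpha': torch.tensor(self.alpha),
              'log_alpha': self.log_alpha.detach().cpu()}, actor_path)
  torch.save(self.critic.state_dict(), critic_path)

  torch.save(self.policy_optim.state_dict(), optimizer_actor_path)
  torch.save(self.critic_optim.state_dict(), optimizer_critic_path)
  torch.save(self.alpha_optim.state_dict(), optimizer_alpha_path)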

Tomeu7 · Apr 29 '21 10:04

That shouldn't happen. I will look into it. I might require more detail on how you resume training. (Sorry for the late reply.)

pranz24 · May 25 '21 13:05

Unfortunately I think this is still the case. When I reload saved parameters I get NaN in my loss function. It looks like it's coming from the QNetworks: after just the first layer, the network outputs NaNs. Also, it's not immediate; it takes anywhere from 30 to 500 steps depending on what I do. It always fails deterministically, though: if I don't change anything, it fails after the same number of steps every time.

I've printed out the state_dict for each loaded parameter. The problem seems to be that the QNetworks (and the critic optimizer) have NaN values in their serialized versions. What's strange is that I'm deserializing a checkpoint from a model that is still running, so the problem might not be the QNetworks themselves but rather their serialization.

Edit: I just ran another test, saving and loading models to check whether they were corrupted, but couldn't find any corruption. That points the finger back at the QNetworks having an exploding-gradient problem or something similar.
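
Something along these lines can be used to scan a saved checkpoint for NaNs before loading it (a rough sketch; the filename is a placeholder):

import torch

def state_dict_has_nan(state_dict):
  # Recurse through nested dicts (e.g. optimizer state) and flag any floating-point tensor containing NaN.
  for value in state_dict.values():
    if isinstance(value, dict):
      if state_dict_has_nan(value):
        return True
    elif torch.is_tensor(value) and value.is_floating_point() and torch.isnan(value).any():
      return True
  return False

critic_state = torch.load('critic_checkpoint.pth')  # placeholder path
print('critic checkpoint contains NaN:', state_dict_has_nan(critic_state))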

BennetLeff · Dec 11 '21 22:12

Sorry, I'm back without an answer, but is it possible that one of the issues is that the alpha optimizer is not saved/loaded via checkpoints?

BennetLeff · Dec 21 '21 01:12

I've had a bit more time to look into this. Interestingly, the critic/Q networks are what fill up with NaN, not the policy, so it's probably not related to the temperature/alpha parameter.

BennetLeff · Jan 13 '22 15:01

> Is this correct? The loss explodes after resuming, which is very strange.

Did you reload the state dicts of the target networks? If not, this might be the reason for the exploding loss. What's more, maybe the replay buffer needs to be stored and reloaded as well : )
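
For example, something along these lines, assuming the agent keeps the target network in a self.critic_target attribute as in this repo's SAC class (a sketch, not tested; critic_target_path, buffer_path, and replay_buffer are placeholders):

# Save/load the target critic alongside everything else ...
torch.save(self.critic_target.state_dict(), critic_target_path)
self.critic_target.load_state_dict(torch.load(critic_target_path))

# ... or, if only the online critic was checkpointed, hard-copy it into the target after loading:
self.critic_target.load_state_dict(self.critic.state_dict())

# The replay buffer can be pickled as well:
import pickle
with open(buffer_path, 'wb') as f:
  pickle.dump(replay_buffer, f)
with open(buffer_path, 'rb') as f:
  replay_buffer = pickle.load(f)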

typoverflow · Feb 22 '22 10:02