pytorch-soft-actor-critic
Resume training
Hello, I am trying to use the SAC agent and resume training. To do that I do:

```python
def load_model(self, actor_path, critic_path, optimizer_actor_path,
               optimizer_critic_path, optimizer_alpha_path):
    policy = torch.load(actor_path)
    self.alpha = policy['alpha'].detach().item()
    self.log_alpha = torch.tensor([policy['log_alpha'].detach().item()],
                                  requires_grad=True, device=self.device)
    # I had to recreate the alpha optimizer so it tracks the newly loaded log_alpha
    self.alpha_optim = Adam([self.log_alpha], lr=self.lr)
    self.policy.load_state_dict(policy['model_state_dict'])
    self.policy.train()
    self.critic.load_state_dict(torch.load(critic_path))
    self.critic.train()
    self.policy_optim.load_state_dict(torch.load(optimizer_actor_path))
    self.critic_optim.load_state_dict(torch.load(optimizer_critic_path))
    self.alpha_optim.load_state_dict(torch.load(optimizer_alpha_path))
```

Is this correct? The loss explodes after resuming, which is very strange.
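For context, here is a sketch of the save side that would match the load code above. The dictionary keys mirror the ones read back in `load_model`; the attribute names (`self.policy`, `self.critic`, the three optimizers) are assumed to be the same ones used there, not necessarily this repo's exact API.

```python
import torch

# Sketch of the save side matching load_model above (same assumed attributes).
def save_model(self, actor_path, critic_path,
               optimizer_actor_path, optimizer_critic_path, optimizer_alpha_path):
    # Actor checkpoint bundles the policy weights with the temperature terms,
    # matching the keys that load_model reads back.
    torch.save({
        "model_state_dict": self.policy.state_dict(),
        "alpha": torch.tensor(self.alpha),
        "log_alpha": self.log_alpha,
    }, actor_path)
    torch.save(self.critic.state_dict(), critic_path)
    torch.save(self.policy_optim.state_dict(), optimizer_actor_path)
    torch.save(self.critic_optim.state_dict(), optimizer_critic_path)
    torch.save(self.alpha_optim.state_dict(), optimizer_alpha_path)
```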
That shouldn't happen. I'll look into it; I might require more detail on how you resume training. (Sorry for the late reply.)
Unfortunately I think this is still the case. When I reload saved parameters I get NaNs in my loss function. It looks like they are coming from the QNetworks: after just the first layer, the network outputs NaNs. Also, it's not immediate; it takes anywhere from 30 to 500 steps depending on what I do. It always fails deterministically, though: if I don't change anything, it fails every time at the same number of steps.
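To pin down where the NaNs first appear, PyTorch's anomaly detection combined with a forward hook can flag the first offending layer. This is a generic debugging sketch; the `q_net` below is a toy stand-in, not the repo's actual QNetwork class.

```python
import torch

# Autograd anomaly detection: the backward pass will raise at the first
# operation that produces NaN, with a traceback pointing at the forward op.
torch.autograd.set_detect_anomaly(True)

# Forward-pass check: a hook that raises as soon as any submodule emits NaNs.
def nan_hook(module, inputs, output):
    if torch.is_tensor(output) and torch.isnan(output).any():
        raise RuntimeError(f"NaN in output of {module.__class__.__name__}")

# Toy stand-in for a Q-network; attach the hook to every submodule.
q_net = torch.nn.Sequential(torch.nn.Linear(4, 16), torch.nn.ReLU(),
                            torch.nn.Linear(16, 1))
for m in q_net.modules():
    m.register_forward_hook(nan_hook)
```

With the hooks in place, the failing step raises immediately at the layer that first outputs NaNs instead of the loss silently going non-finite later.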
I've printed out the state_dict for each loaded param. The problem seems to be that the QNetworks (and the critic optimizer) have NaN values in their serialized versions. What's strange is that I'm deserializing a checkpoint from a model that is still running. So the problem might not be the QNetworks themselves, but rather the serialization of them.
Edit: I just ran another test by saving/loading models and checking whether they were corrupted, but couldn't find any such thing. That points the finger back at the QNetworks having an exploding-gradient problem or something similar.
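One way to make the corruption test systematic is to scan every tensor in a saved checkpoint for non-finite values. This is a small standalone helper (not part of the repo) that recurses through nested state dicts:

```python
import torch

def find_nonfinite(state_dict, prefix=""):
    """Return the keys of all tensors in a (possibly nested) state dict
    that contain NaN or Inf values."""
    bad = []
    for key, value in state_dict.items():
        full = f"{prefix}{key}"
        if isinstance(value, dict):
            # Optimizer state dicts nest tensors inside per-parameter dicts.
            bad += find_nonfinite(value, full + ".")
        elif torch.is_tensor(value) and value.is_floating_point():
            if not torch.isfinite(value).all():
                bad.append(full)
    return bad
```

Running e.g. `find_nonfinite(torch.load(critic_path))` on each checkpoint file would distinguish "corrupted on disk" from "corrupted during training after the reload".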
Sorry, I'm back without an answer, but is it possible that one of the issues is that the alpha optimizer is not saved/loaded via checkpoints?
I had a bit more time to look into this. Interestingly, it's the critic/Q networks that are filling up with NaNs, not the policy, so it's probably not related to the temperature/alpha parameter.
Did you reload the state dicts of the target networks? If not, this might be the reason for the exploding loss. What's more, the replay buffer may need storing/reloading as well : )
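A checkpoint that also covers the target critic and the replay buffer could look like the sketch below. The attribute names `critic_target` and `memory`, and the single-file layout, are assumptions based on common SAC implementations, not necessarily this repo's exact API.

```python
import torch

# Sketch: one checkpoint file holding everything needed to resume exactly,
# including the target critic and the replay buffer.
def save_checkpoint(agent, memory, path):
    torch.save({
        "policy": agent.policy.state_dict(),
        "critic": agent.critic.state_dict(),
        "critic_target": agent.critic_target.state_dict(),
        "log_alpha": agent.log_alpha,
        "policy_optim": agent.policy_optim.state_dict(),
        "critic_optim": agent.critic_optim.state_dict(),
        "alpha_optim": agent.alpha_optim.state_dict(),
        "replay_buffer": memory,  # e.g. the list of stored transitions
    }, path)

def load_checkpoint(agent, path):
    ckpt = torch.load(path, weights_only=False)
    agent.policy.load_state_dict(ckpt["policy"])
    agent.critic.load_state_dict(ckpt["critic"])
    agent.critic_target.load_state_dict(ckpt["critic_target"])
    # Copy log_alpha in place so the existing alpha optimizer keeps a valid
    # reference to the tensor it is optimizing.
    with torch.no_grad():
        agent.log_alpha.copy_(ckpt["log_alpha"])
    agent.policy_optim.load_state_dict(ckpt["policy_optim"])
    agent.critic_optim.load_state_dict(ckpt["critic_optim"])
    agent.alpha_optim.load_state_dict(ckpt["alpha_optim"])
    return ckpt["replay_buffer"]
```

Restoring the target networks matters because, if they are reinitialized randomly, the soft Bellman targets are wrong until the Polyak averaging catches up, which can blow up the Q losses right after resuming.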