tqc_pytorch
Incompatible with PyTorch 2.0 - variable modified inplace
Getting the following error when trying to run the code with a (very simple) custom env using PyTorch 2.0.1:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [512, 25]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead.
By setting `torch.autograd.set_detect_anomaly(True)`, we get the following additional error trace:
UserWarning: Error detected in AddmmBackward0. Traceback of forward call that caused the error:
File "/Users/max/PycharmProjects/pythonProject1/tqc/tqc_testing.py", line 97, in <module>
trainer.train(replay_buffer, batch_size)
File "/Users/max/PycharmProjects/pythonProject1/tqc/trainer.py", line 61, in train
actor_loss = (alpha * log_pi - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()
File "/Users/max/PycharmProjects/pythonProject1/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/Users/max/PycharmProjects/pythonProject1/tqc/structures.py", line 100, in forward
quantiles = torch.stack(tuple(net(sa) for net in self.nets), dim=1)
When executing the same code with PyTorch 1.4, no errors are thrown and the algorithm trains as expected. Therefore, the culprit can't be the custom environment I'm using (a simple numerical solver, updating a numpy array at each time step).
Could it be related to this line: `actor_loss = (alpha * log_pi - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()`?
The problem is the ordering of the updates: `actor_loss` is computed through the critic before the critic update, but `actor_loss.backward()` only runs after `critic_optimizer.step()` has modified the critic's weights in place, so the tensors saved for that backward pass are stale and PyTorch's version check fails.
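For illustration, here is a minimal, self-contained sketch of the same pattern outside of TQC (the two `nn.Linear` modules are stand-ins, not the repo's networks): a loss is built through the critic, the critic optimizer then modifies its weights in place, and only afterwards is that loss's backward pass run. On recent PyTorch this trips the autograd version check with exactly the error above; on 1.4 the optimizer update was not caught by this check, which is presumably why the original code trained fine there.

```python
import torch
import torch.nn as nn

# Stand-in networks, not the repo's classes; shapes are arbitrary.
actor = nn.Linear(4, 4)
critic = nn.Linear(4, 1)
actor_opt = torch.optim.Adam(actor.parameters(), lr=3e-4)
critic_opt = torch.optim.Adam(critic.parameters(), lr=3e-4)

state = torch.randn(8, 4)

# The actor loss is built through the critic, so the critic's weight
# matrix is saved for the later backward pass.
actor_loss = -critic(actor(state)).mean()

# Critic update: optimizer.step() modifies critic.weight in place,
# bumping its version counter.
critic_loss = critic(torch.randn(8, 4)).pow(2).mean()
critic_opt.zero_grad()
critic_loss.backward()
critic_opt.step()

# Backward through the now-modified weights: on current PyTorch this
# raises "one of the variables needed for gradient computation has
# been modified by an inplace operation".
actor_opt.zero_grad()
actor_loss.backward()
actor_opt.step()  # never reached
```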
One solution might be to move the `actor_loss` computation after the critic update, but I am not sure whether this might make the training unstable:
```python
# --- Policy and alpha loss ---
new_action, log_pi = self.actor(state)
alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()
# actor_loss moved below, after the critic update:
# actor_loss = (alpha * log_pi - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()

# --- Update ---
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

# Soft update of the target critic
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

# Recompute actor_loss with the freshly updated critic
actor_loss = (alpha * log_pi - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
```
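If recomputing `actor_loss` against the freshly updated critic turns out to change the training dynamics, another option (untested, just a sketch) would be to keep the loss computation where it is and run the whole actor update before the critic update, so the actor gradients are still taken through the pre-update critic. Note that `actor_loss.backward()` also accumulates gradients into the critic's parameters, so `critic_optimizer.zero_grad()` has to come after the actor step, as below; this assumes (as in the original trainer) that `critic_loss` does not backpropagate into the actor.

```python
# --- Policy and alpha loss (computed as before, against the current critic) ---
new_action, log_pi = self.actor(state)
alpha_loss = -self.log_alpha * (log_pi + self.target_entropy).detach().mean()
actor_loss = (alpha * log_pi - self.critic(state, new_action).mean(2).mean(1, keepdim=True)).mean()

# --- Actor update first, while the critic weights are still the ones
#     the actor_loss graph was built with ---
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()

# --- Critic update afterwards; zero_grad() also clears the gradients that
#     actor_loss.backward() accumulated into the critic's parameters ---
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()

# Soft update of the target critic (unchanged)
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
    target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
```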