rlpyt
Exception in logger with PyTorch >= 1.4
In PyTorch >= 1.4, grad_norm is a torch tensor rather than a float (changed in https://github.com/pytorch/pytorch/pull/32020), so the logger throws an exception here, because the values list now contains PyTorch tensors: https://github.com/astooke/rlpyt/blob/a54cb5b1ee7b68d757aa0baa6a2786548419e366/rlpyt/utils/logging/logger.py#L457
To maintain backward compatibility, an easy fix is to change the value passed to the append calls, like https://github.com/astooke/rlpyt/blob/35af87255465b3644747294f7fd1ff6045dab910/rlpyt/algos/dqn/dqn.py#L184, from grad_norm to torch.tensor(grad_norm).item().
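A minimal sketch of the call-site fix, assuming the algorithm collects gradient norms in a per-iteration list before handing them to the logger. The names as_float, record_grad_norm, and opt_info are hypothetical stand-ins, not rlpyt's actual code. Instead of torch.tensor(grad_norm).item(), it checks for an .item() method, which covers both plain floats and zero-dimensional tensors without constructing a new tensor (wrapping an existing tensor in torch.tensor() may emit a UserWarning in recent PyTorch versions).

```python
from typing import Any, Dict, List


def as_float(value: Any) -> float:
    """Convert a gradient norm to a plain Python float.

    Plain floats (PyTorch < 1.4) pass through float(); zero-dimensional
    tensors (PyTorch >= 1.4) expose .item(), which returns a Python number.
    """
    if hasattr(value, "item"):
        return float(value.item())
    return float(value)


# Hypothetical stand-in for the per-iteration info collected by an algorithm.
opt_info: Dict[str, List[float]] = {"gradNorm": []}


def record_grad_norm(grad_norm: Any) -> None:
    # Store only Python floats, so the logger never sees tensors.
    opt_info["gradNorm"].append(as_float(grad_norm))
```

In training code this would receive the return value of torch.nn.utils.clip_grad_norm_, which is a tensor in PyTorch >= 1.4 and a float in earlier versions, so the same call works on both.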
I am not sure whether you would rather fix this in the logger, though.
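If the fix goes into the logger instead, one option is to normalize every entry of the values list to a float before computing statistics. This is a hypothetical sketch: misc_stats and its key suffixes only loosely mirror the summary statistics a tabular logger typically records, and it uses the standard-library statistics module rather than NumPy so it runs without extra dependencies.

```python
import statistics
from typing import Any, Dict, Iterable

_SUFFIXES = ("Average", "Std", "Median", "Min", "Max")


def _to_float(value: Any) -> float:
    # Zero-dimensional torch tensors (and NumPy scalars) expose .item().
    return float(value.item()) if hasattr(value, "item") else float(value)


def misc_stats(key: str, values: Iterable[Any]) -> Dict[str, float]:
    """Hypothetical logger-side helper: convert values, then summarize them."""
    floats = [_to_float(v) for v in values]
    if not floats:
        # Nothing recorded this iteration: report NaN for every statistic.
        return {key + s: float("nan") for s in _SUFFIXES}
    return {
        key + "Average": statistics.mean(floats),
        key + "Std": statistics.pstdev(floats),  # population std, like np.std
        key + "Median": statistics.median(floats),
        key + "Min": min(floats),
        key + "Max": max(floats),
    }
```

Converting in one place means every algorithm benefits without touching its append calls, at the cost of silently accepting tensors wherever they slip in.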
Oh, good idea on that bit of backward compatibility! Thanks for posting. At some point I will probably just move everything forward to 1.4 or 1.5, unless there is some reason not to?