scratchai
Dependency issues with torch zero_gradients
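PyTorch has removed the function `zero_gradients` from `torch.autograd.gradcheck` [1][2][3], so code that imports it from there fails on newer releases. The same function is redefined below so that such code keeps working.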
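```python
import collections.abc
import torch

def zero_gradients(x):
    # Zero the .grad of a tensor, or recurse into an iterable of tensors.
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()  # detach the gradient from any autograd graph
            x.grad.zero_()    # then zero it in place
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)
```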
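A short usage sketch. PyTorch accumulates gradients across `backward()` calls, so code that repeatedly takes the gradient with respect to the same tensor (for example, an input that is updated step by step) has to clear `.grad` between passes. The tensor values and the loss below are made up for illustration. The fence repeats the redefinition so that it runs on its own.

```python
import collections.abc

import torch


def zero_gradients(x):
    # Same redefinition as above: zero .grad on a tensor or recurse into an iterable.
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)


# Illustrative input and loss, chosen only to show gradient accumulation.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
grads = []
for _ in range(2):
    loss = (x ** 2).sum()          # d(loss)/dx = 2 * x
    loss.backward()
    grads.append(x.grad.clone())   # clone, because zero_() below works in place
    zero_gradients([x])            # clear .grad before the next backward pass

print(grads[0].tolist(), grads[1].tolist(), x.grad.tolist())
# [2.0, 4.0, 6.0] [2.0, 4.0, 6.0] [0.0, 0.0, 0.0]

# Without zeroing, the two backward passes add up.
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
for _ in range(2):
    (y ** 2).sum().backward()
print(y.grad.tolist())  # [4.0, 8.0, 12.0]
```

For `nn.Module` parameters, the built-in `torch.nn.Module.zero_grad()` and `torch.optim.Optimizer.zero_grad()` cover the same need. For a single input tensor, setting `x.grad = None` is another option.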
References:
[1] Original definition in torch/autograd/gradcheck.py (PyTorch source): https://github.com/pytorch/pytorch/blob/819d4b2b83fa632bf65d14f6af80a09e7476e87e/torch/autograd/gradcheck.py#L15
[2] PyTorch Forums thread: https://discuss.pytorch.org/t/from-torch-autograd-gradcheck-import-zero-gradients/127462
[3] Stack Overflow question: https://stackoverflow.com/questions/68419612/imorting-zero-gradients-from-torch-autograd-gradcheck