Redefine zero_gradients for backward and forward compatibility
PyTorch has removed the zero_gradients function. This PR redefines it so existing code keeps working on newer PyTorch versions.
import collections.abc

import torch

def zero_gradients(x):
    # Zero the gradient of a single tensor in place.
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    # Recurse into iterables of tensors (lists, tuples, nested containers).
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)
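For reference, a minimal sanity check of the redefined function (a sketch; assumes PyTorch is installed, and the tensor values here are arbitrary):

```python
import collections.abc

import torch

def zero_gradients(x):
    # Same definition as proposed in this PR.
    if isinstance(x, torch.Tensor):
        if x.grad is not None:
            x.grad.detach_()
            x.grad.zero_()
    elif isinstance(x, collections.abc.Iterable):
        for elem in x:
            zero_gradients(elem)

# Populate a gradient, then zero it.
x = torch.tensor([1.0, 2.0], requires_grad=True)
(x ** 2).sum().backward()
zero_gradients(x)
print(bool(torch.all(x.grad == 0)))  # gradient is now all zeros
```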
References:
[1] https://github.com/pytorch/pytorch/blob/819d4b2b83fa632bf65d14f6af80a09e7476e87e/torch/autograd/gradcheck.py#L15
[2] https://discuss.pytorch.org/t/from-torch-autograd-gradcheck-import-zero-gradients/127462
[3] https://stackoverflow.com/questions/68419612/imorting-zero-gradients-from-torch-autograd-gradcheck
Hi Arunava, I was wondering if you could merge this. A course at CMU uses your library, and it currently conflicts with the latest version of PyTorch, forcing students to reinstall an older version of PyTorch just for the sake of the assignment. It would be helpful if you could merge the PR.