How to retain the grad via __torch_dispatch__ for torch.Tensor methods
I have a question, which might be very simple, but I have no idea how to fix it.
I am trying to subclass torch.Tensor and want to retain the grad through the original torch.Tensor methods.
Here is my code:
import torch
from torch.utils._pytree import tree_map

class MyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor):
        return torch.Tensor.as_subclass(tensor, cls)

    def __init__(self, tensor):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        return self.__class__.__name__ + ':\n' + self.tensor.__repr__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t

        return tree_map(wrap, (super().__torch_dispatch__(func, types, args, kwargs)))

    def my_method(self):
        return self.tensor.exp()
- Here is the result:
>>> x = MyTensor(torch.randn(3, requires_grad=True))
>>> x
MyTensor:
tensor([1.4196, 2.0849, 1.2102], requires_grad=True)
- The original method doesn't retain grad:
>>> x.exp()
MyTensor:
tensor([4.1355, 8.0442, 3.3543])
- Newly defined method retains grad:
>>> x.my_method()
tensor([4.1355, 8.0442, 3.3543], grad_fn=<ExpBackward0>)
If I use __torch_function__, it can retain the grad. How can I retain the grad by using __torch_dispatch__?
Thank you so much!
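(For context, a minimal sketch, not the poster's actual code, of why a __torch_function__-based wrapper keeps the grad: __torch_function__ runs above autograd, so the operation is still recorded on the result.)

import torch

class MyTensorTF(torch.Tensor):
    # Keep the default __torch_function__ behaviour instead of disabling it.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # __torch_function__ sits above autograd, so func is recorded as usual.
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(3, requires_grad=True).as_subclass(MyTensorTF)
print(x.exp())  # the result still has a grad_fn, unlike the __torch_dispatch__ version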
Hi,
When you do x.exp().sum().backward(), do you expect x.grad to be populated? Or x.tensor.grad to be populated?
I would prefer x.grad, but I want to know how to do that for both cases, since both might be useful in the future.
The high-level idea is that you have to choose: either x gets autograd or x.tensor does, but it can't be both.
Here is an extension to your script to show how to do some of these things:
import torch
from torch.utils._pytree import tree_map

class MyTensorWithGrad(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor, *, requires_grad=False):
        assert tensor.requires_grad == False, "Only the wrapper should require gradients"
        return torch.Tensor._make_subclass(cls, tensor, require_grad=requires_grad)

    def __init__(self, tensor, *, requires_grad=False):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        autograd_info = f"grad_fn={self.grad_fn}" if self.grad_fn else \
            f"requires_grad={self.requires_grad}"
        return f"{self.__class__.__name__}({self.tensor.__repr__()}, {autograd_info})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t

        return tree_map(wrap, (super().__torch_dispatch__(func, types, args, kwargs)))

    def my_method(self):
        # This method lives "above" autograd, so we should NOT access the ".tensor"
        # attribute, which is not differentiable.
        # Use a custom Function to make this differentiable.
        class MyMethod(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                # Here it is ok to access .tensor in a non-differentiable way!
                ctx.save_for_backward(inp)
                return inp.tensor.exp()

            @staticmethod
            def backward(ctx, gO):
                inp, = ctx.saved_tensors
                # d/dx exp(x) = exp(x)
                return inp.tensor.exp() * gO

        return MyMethod.apply(self)

    # If you don't want to have to write a custom Function for everything,
    # you can create a way to get the `.tensor` in a differentiable way!
    # (similar to .values() on sparse Tensors)
    def get_tensor_attr(self):
        class MyAccessor(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.tensor

            @staticmethod
            def backward(ctx, gO):
                return gO

        return MyAccessor.apply(self)

    def my_other_method(self):
        return self.get_tensor_attr().exp()


x = MyTensorWithGrad(torch.randn(3), requires_grad=True)
print(x)
print("my_method")
print(x.my_method())
print("tensor")
print(x.tensor)
print("get_tensor_attr")
print(x.get_tensor_attr())
print("my_other_method")
print(x.my_other_method())
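A quick check (not part of the original snippet) of where the gradient ends up with the class above: it accumulates on the wrapper x, not on x.tensor.

out = x.my_method()       # plain tensor carrying a grad_fn from the custom Function
out.sum().backward()
print(x.grad)             # populated on the wrapper
print(x.tensor.grad)      # stays None; the inner tensor never required grad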
Thank you so much for your reply. Things are a little bit complicated on our side.
We are developing an open-source project, PyPose, using PyTorch, and are subclassing torch.Tensor to represent Lie Algebra and Lie Group.
One of our developers has asked a question (712) regarding using vmap and jacrev to compute the Jacobian. Previously we used __torch_function__ for subclassing; after seeing your reply, we are considering using __torch_dispatch__, but we are not sure how to handle it.
Basically, we have the following objectives:
- We want x to get autograd (we actually don't have x.tensor).
- We use vmap and jacrev to compute the Jacobian matrix.
- The subclass constructor needs to retain the gradient, so in __new__() we use torch.Tensor.as_subclass(tensor, cls) instead of torch.Tensor._make_subclass(cls, tensor), since the input tensor can be the output of a neural network, which needs to track the grad for training.
- Our current implementation using __torch_function__ raises the error mentioned above, but when we try __torch_dispatch__, it seems that grad cannot be retained.
You can see our current implementation here.
Any suggestions for this? Thank you so much!
For your questions on our use case, @zou3519, you can also refer to the above link.
the subclass constructor needs to retain the gradient, so in __new__(), we use torch.Tensor.as_subclass(tensor, cls) instead of torch.Tensor._make_subclass(cls, tensor), since the input tensor can be the output of a neural network, which needs to track the grad for training.
This one can be done in a similar way to the "differentiable accessor" above, but as a "differentiable constructor":
class MyTensorWithGrad(torch.Tensor):
    # ... rest of the class from above ...

    @staticmethod
    def from_tensor(t):
        class MyConst(torch.autograd.Function):
            @staticmethod
            def forward(ctx, t):
                # detach() so that only the wrapper tracks gradients below this Function
                return MyTensorWithGrad(t.detach())

            @staticmethod
            def backward(ctx, gO):
                return gO

        return MyConst.apply(t)


inp = torch.rand(3, requires_grad=True)
x = MyTensorWithGrad.from_tensor(inp)
print(x)
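Under the same assumptions, a hypothetical sanity check that gradients flow back through the differentiable constructor to the original tensor:

x.my_method().sum().backward()
print(inp.grad)   # gradient reaches the tensor that came out of the network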
use vmap and jacrev to compute the Jacobian matrix.
How big are these Tensors? You can use vanilla PyTorch functions to get the Jacobian as well; you can do so via torch.autograd.functional.jacobian.
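For reference, a minimal, self-contained use of torch.autograd.functional.jacobian on a plain function (f below is just a placeholder):

import torch

def f(x):
    return x.exp()

x = torch.randn(3)
J = torch.autograd.functional.jacobian(f, x)
print(J.shape)  # torch.Size([3, 3]): output dimensions first, then input dimensions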