`grad(factory_op)` breaks with mixed cpu/cuda inputs
This seems like a minor issue, but the following code breaks:
import torch
from functorch import grad

def foo(x):
    z = torch.zeros(1)  # factory func allocating on cpu
    z.copy_(x)          # cpu_tensor.copy_(cuda_tensor)
    return z.sum()

x = torch.tensor(3.14, device='cuda')
grad(foo)(x)
It fails with the error NotImplementedError: Cannot access storage of TensorWrapper.
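For what it's worth, a minimal sketch of a workaround, assuming the goal is just to avoid the mixed-device copy: allocating the factory tensor on the input's device sidesteps the error.

import torch
from functorch import grad

def foo(x):
    # assumption: allocate on the input's device instead of the cpu default
    z = torch.zeros(1, device=x.device)
    z.copy_(x)  # same-device copy, so no cpu/cuda mix under grad
    return z.sum()

x = torch.tensor(3.14, device='cuda')
print(grad(foo)(x))  # expected: tensor(1., device='cuda:0')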
I'm not sure whether the cpu_tensor.copy_(cuda_tensor) is actually valid, but it seems to work without the grad call.
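As a sanity check, here is the same cross-device copy in eager mode, outside of grad; copy_ is documented to accept a source on a different device, and this runs without error:

import torch

x = torch.tensor(3.14, device='cuda')
z = torch.zeros(1)  # cpu factory tensor
z.copy_(x)          # cpu_tensor.copy_(cuda_tensor), no grad involved
print(z)            # tensor([3.1400])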