Repository: functorch

`make_fx` fails with `jacfwd` (when used with `torch.add(Tensor, Scalar)`)

Status: Open · kshitij12345 opened this issue 3 years ago · 1 comment

import torch
import functorch

# dtype and device for the input tensor built below. The failure shown in the
# traceback was reported with this CPU / float32 setup.
dtype = torch.float32
device = torch.device('cpu')

def foo(x: torch.Tensor) -> torch.Tensor:
    """Add the Python float scalar ``1.0`` to ``x``.

    This is the minimal function for the reproduction: tensor + Python scalar,
    which is the ``torch.add(Tensor, Scalar)`` case named in the issue title.
    """
    # Keep the exact `x + 1.0` form. Rewriting it (e.g. as an explicit
    # torch.add call) might dispatch a different overload while tracing and
    # hide the failure; this is unverified, so keep the repro as reported.
    return x + 1.0

# A single-element (1x1) input tensor, built with the dtype/device chosen above.
x = torch.tensor([[0.0]], dtype=dtype, device=device)

# Trace three functorch transforms of `foo` with make_fx. As reported in this
# issue, the vmap and jacrev versions trace successfully. The jacfwd version
# raises the AssertionError in the traceback below: `assert device.type !=
# "meta"` inside FakeTensor.__init__, reached from proxy_tensor's set_meta
# after func.decompose.
functorch.make_fx(functorch.vmap(foo))(x)  # Works
functorch.make_fx(functorch.jacrev(foo))(x)  # Works
functorch.make_fx(functorch.jacfwd(foo))(x)  # Fails

Error message:

Traceback (most recent call last):
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 31, in <module>
    functorch.make_fx(functorch.jacfwd(foo))(x)  # Fails
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 683, in wrapped
    t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 441, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/_symbolic_trace.py", line 739, in trace
    (self.create_arg(fn(*args)),),
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 457, in wrapped
    out = f(*tensors)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 996, in wrapper_fn
    results = vmap(push_jvp, randomness=randomness)(basis)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 362, in wrapped
    return _flat_vmap(
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 489, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 989, in push_jvp
    output = _jvp_with_argnums(func, args, basis, argnums=argnums, has_aux=has_aux)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/vmap.py", line 35, in fn
    return f(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_functorch/eager_transforms.py", line 837, in _jvp_with_argnums
    result_duals = func(*duals)
  File "/home/kshiteej/Pytorch/pytorch_functorch/test/test_scratch.py", line 26, in foo
    return x + 1.0
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 259, in proxy_call
    r = func.decompose(*args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_ops.py", line 307, in decompose
    return self._op_dk(dk, *args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 483, in __torch_dispatch__
    return self.inner_torch_dispatch(func, types, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 508, in inner_torch_dispatch
    out = proxy_call(self, func, args, kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 393, in proxy_call
    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 206, in track_tensor_tree
    wrap_with_proxy(inner_res, proxy_res, constant)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 185, in wrap_with_proxy
    set_meta(proxy, e)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/fx/experimental/proxy_tensor.py", line 149, in set_meta
    proxy.node.meta['val'] = torch.empty_strided(val.shape, val.stride(), device=val.device, dtype=val.dtype)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 878, in __torch_dispatch__
    op_impl_out = op_impl(self, func, *args, **kwargs)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 325, in constructors
    return FakeTensor(fake_mode, r, out_device)
  File "/home/kshiteej/.conda/envs/pytorch-cuda-dev/lib/python3.9/site-packages/torch/_subclasses/fake_tensor.py", line 560, in __init__
    assert device.type != "meta"
AssertionError

— kshitij12345 (Dec 07 '22 12:12)

This could be related to https://github.com/pytorch/pytorch/issues/90065.

— zou3519 (Dec 07 '22 19:12)