
"RuntimeError: !at::functionalization::impl::isFunctionalTensor(t)" when running a DTensor test with functionalization on

Open · jeffhataws opened this issue 5 months ago · 3 comments

🐛 Bug

When running the new DTensor placement test test/spmd/test_dtensor_integration3.py with functionalization on (default), I get the following error:

======================================================================
ERROR: test_xla_placement (__main__.DTensorIntegrationTest3)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/test/spmd/test_dtensor_integration3.py", line 74, in test_xla_placement
    outputs_sharded = forward_pure(inputs, in_proj_weight, out_proj_weight)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/test/spmd/test_dtensor_integration3.py", line 45, in forward_pure
    hidden = torch.matmul(hidden, in_proj_weight.T)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/torch_xla/distributed/spmd/xla_sharded_tensor.py", line 195, in __torch_function__
    return super().__torch_function__(func, types, args, kwargs)
  File "/home/ubuntu/pytorch/torch/_tensor.py", line 1682, in __torch_function__
    ret = func(*args, **kwargs)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/torch_xla/distributed/spmd/xla_sharded_tensor.py", line 190, in __torch_dispatch__
    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
  File "/home/ubuntu/pytorch/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
RuntimeError: !at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED at "/home/ubuntu/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":838, please report a bug to PyTorch. The composite op functionalization fallback expects its inputs all not to be functional tensors

To Reproduce

Steps to reproduce the behavior:

python test/spmd/test_dtensor_integration3.py

Passes with:

XLA_DISABLE_FUNCTIONALIZATION=1 python test/spmd/test_dtensor_integration3.py
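
For readers without the test file handy, here is a minimal sketch of the failing pattern, reconstructed from the traceback above. The mesh/sharding setup, tensor shapes, and variable names are assumptions, not the actual test body; only the matmul-on-a-sharded-weight pattern is taken from the trace.

import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import Mesh, mark_sharding

# Assumed setup: enable SPMD mode and build a simple 2D device mesh.
xr.use_spmd()
device = torch_xla.device()
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.arange(num_devices), (num_devices, 1), ("x", "y"))

inputs = torch.randn(8, 16, device=device)
in_proj_weight = torch.randn(32, 16, device=device)

# mark_sharding returns an XLAShardedTensor, whose __torch_function__ /
# __torch_dispatch__ (xla_sharded_tensor.py) appear in the traceback above.
sharded_weight = mark_sharding(in_proj_weight, mesh, (0, 1))

# With functionalization on (the default), this matmul hits the
# "!at::functionalization::impl::isFunctionalTensor(t)" assert.
hidden = torch.matmul(inputs, sharded_weight.T)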

Expected behavior

No crash with functionalization on (the default).

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: all
  • torch_xla version: v2.8, v2.9

Additional context

jeffhataws · Jul 11 '25, 15:07

@bfolie any idea what the error above could be? "The composite op functionalization fallback expects its inputs all not to be functional tensors."

jeffhataws · Jul 11 '25, 17:07

I think @amjames has run into this problem before when working on #9441.

That error is raised because we are calling operations such as select_symint, which internally calls at::functionalization::functionalize_aten_op_symint, in a context where we have not gone through the functionalization layer. This can happen specifically when an XLA tensor subclass implements __torch_dispatch__. In summary, functionalize_aten_op_symint expects its arguments to have already been unwrapped from their functional tensors, which is what the functionalization layer normally does; since we never went through that layer, we get this error.

I believe this happens because the tensor subclass itself might not be a functional tensor (only the tensors it wraps are), which means that by the time we get to __torch_dispatch__, the functionalization layer is unreachable due to this guard.
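
A quick diagnostic sketch for that hypothesis (dump_functional_status is a hypothetical helper, not existing code): probe the arguments at the top of __torch_dispatch__ with torch._is_functional_tensor. If the hypothesis holds, the subclass wrapper reports False while the functional tensors it wraps (and other functionalized inputs) report True.

import torch
from torch.utils._pytree import tree_map_only

def dump_functional_status(args, kwargs):
    # Print, for every tensor reaching __torch_dispatch__, whether it is
    # wrapped in a FunctionalTensorWrapper.
    def probe(t):
        print(type(t).__name__, "functional:", torch._is_functional_tensor(t))
        return t
    tree_map_only(torch.Tensor, probe, (args, kwargs or {}))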

One way to solve it quickly is to call torch._from_functional_tensor(t) on every tensor argument of said function. Note, however, that any code inside __torch_dispatch__, in this context, will only be dispatched to keys below the Python dispatch key (i.e. declared before it).
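
A minimal sketch of that workaround, assuming the hook point is the existing unwrap helper used by XLAShardedTensor.__torch_dispatch__ in xla_sharded_tensor.py (illustrative, not the actual patch):

import torch
from torch.utils._pytree import tree_map
from torch_xla.distributed.spmd import XLAShardedTensor

def unwrap(t):
    # Existing behavior: replace the subclass with the tensor it wraps.
    if isinstance(t, XLAShardedTensor):
        t = t.global_tensor
    # Suggested addition: also peel off the FunctionalTensorWrapper so the
    # composite-op functionalization fallback never sees a functional tensor.
    if isinstance(t, torch.Tensor) and torch._is_functional_tensor(t):
        t = torch._from_functional_tensor(t)
    return t

# The forwarding inside __torch_dispatch__ would then stay the same:
#     func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

The caveat above still applies: whatever func executes afterwards is dispatched below the Python dispatch key, so it will not re-enter this __torch_dispatch__.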

ysiraichi · Jul 12 '25, 23:07

I can take a look at this; it is the same issue I was dealing with.

amjames · Jul 15 '25, 18:07