"RuntimeError: !at::functionalization::impl::isFunctionalTensor(t)" when running a DTensor test with functionalization on
## 🐛 Bug

When running the new DTensor placement test `test/spmd/test_dtensor_integration3.py` with functionalization on (the default), I get the following error:
```
======================================================================
ERROR: test_xla_placement (__main__.DTensorIntegrationTest3)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/test/spmd/test_dtensor_integration3.py", line 74, in test_xla_placement
    outputs_sharded = forward_pure(inputs, in_proj_weight, out_proj_weight)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/test/spmd/test_dtensor_integration3.py", line 45, in forward_pure
    hidden = torch.matmul(hidden, in_proj_weight.T)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/torch_xla/distributed/spmd/xla_sharded_tensor.py", line 195, in __torch_function__
    return super().__torch_function__(func, types, args, kwargs)
  File "/home/ubuntu/pytorch/torch/_tensor.py", line 1682, in __torch_function__
    ret = func(*args, **kwargs)
  File "/home/ubuntu/pt2.8_sws/pytorch/xla/torch_xla/distributed/spmd/xla_sharded_tensor.py", line 190, in __torch_dispatch__
    func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
  File "/home/ubuntu/pytorch/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
RuntimeError: !at::functionalization::impl::isFunctionalTensor(t) INTERNAL ASSERT FAILED at "/home/ubuntu/pytorch/aten/src/ATen/FunctionalTensorWrapper.cpp":838, please report a bug to PyTorch. The composite op functionalization fallback expects its inputs all not to be functional tensors
```
## To Reproduce

Steps to reproduce the behavior:

```
python test/spmd/test_dtensor_integration3.py
```

The test passes with functionalization disabled:

```
XLA_DISABLE_FUNCTIONALIZATION=1 python test/spmd/test_dtensor_integration3.py
```
## Expected behavior

The test should run without crashing when functionalization is enabled.
## Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: all
- torch_xla version: v2.8, v2.9
## Additional context
@bfolie any idea what could be causing the error above? "The composite op functionalization fallback expects its inputs all not to be functional tensors".
I think @amjames has run into this problem before when working on #9441.
That error is raised because we are calling operations such as `select_symint`, which internally calls `at::functionalization::functionalize_aten_op_symint`, in a context where we have not gone through the functionalization layer. This can happen specifically when we implement an XLA tensor subclass with `__torch_dispatch__`. In summary, `functionalize_aten_op_symint` expects its arguments to already be unwrapped from their functional tensors, which is what the functionalization layer normally does; since we never went through that layer, the assert fires. A rough sketch of the failing dispatch path is shown below.
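For reference, here is roughly what the re-dispatch in `XLAShardedTensor.__torch_dispatch__` looks like, paraphrased from the traceback above (the `unwrap` helper and the `global_tensor` attribute are my reading of the subclass, not an exact copy of the source):

```python
import torch
from torch.utils._pytree import tree_map

class XLAShardedTensor(torch.Tensor):
    # ... construction elided ...

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(x):
            # Unwraps the subclass, but the inner XLA tensor can still be a
            # FunctionalTensorWrapper: nothing strips the functional layer here.
            return x.global_tensor if isinstance(x, XLAShardedTensor) else x

        # For a composite op, this re-dispatch lands on the functionalization
        # fallback (e.g. functionalize_aten_op_symint), which asserts that its
        # inputs are NOT functional tensors; hence the INTERNAL ASSERT.
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
```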
I believe this happens because the tensor subclass itself might not be a functional tensor (only its inner tensors are), which means that by the time we get to `__torch_dispatch__`, the functionalization layer is unreachable due to this guard.
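A quick way to see the mismatch (a hypothetical check, assuming an XLA device with functionalization enabled; the `XLAShardedTensor` construction is illustrative):

```python
import torch
import torch_xla  # registers the 'xla' device

# With functionalization on (the default), a plain XLA tensor is wrapped
# in a FunctionalTensorWrapper...
inner = torch.randn(4, 4, device='xla')
print(torch._is_functional_tensor(inner))  # expected: True

# ...but the subclass wrapper itself is not, so dispatch skips the
# Functionalize key and goes straight to __torch_dispatch__:
# wrapped = XLAShardedTensor(inner)            # illustrative construction
# print(torch._is_functional_tensor(wrapped))  # expected: False
```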
One way to quickly solve it is to call `torch._from_functional_tensor(t)` on every tensor argument of said function, as in the sketch below. Note, however, that any code inside `__torch_dispatch__` in this context will only be dispatched to keys below the Python dispatch key (i.e. declared before it).
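A minimal sketch of that workaround (untested; `defunctionalize` is a hypothetical helper name):

```python
import torch
from torch.utils._pytree import tree_map

def defunctionalize(x):
    # Strip the FunctionalTensorWrapper so composite-op fallbacks such as
    # functionalize_aten_op_symint see plain backend tensors.
    if isinstance(x, torch.Tensor) and torch._is_functional_tensor(x):
        return torch._from_functional_tensor(x)
    return x

# Inside XLAShardedTensor.__torch_dispatch__, applied after unwrapping the
# subclass and before re-dispatching:
#     args = tree_map(defunctionalize, tree_map(unwrap, args))
#     kwargs = tree_map(defunctionalize, tree_map(unwrap, kwargs))
#     return func(*args, **kwargs)
```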
I can take a look at this; it is the same issue I was dealing with.