split on second dimension of 2D array not working with XLA_DISABLE_FUNCTIONALIZATION=1
🐛 Bug
When running a small example that splits a 2D array along its second dimension, the resulting tensors don't contain the expected data; the results differ between CPU and XLA:CPU.
To Reproduce
Run the following test:
import torch
import torch_xla
a_golden = torch.arange(12, device="cpu").reshape(3, 4)
b_golden, c_golden = a_golden.split([3, 1], dim=-1)
a_xla = torch.arange(12, device="xla").reshape(3, 4)
b_xla, c_xla = a_xla.split([3, 1], dim=-1)
print("a original:", a_golden)
print("b golden :", b_golden)
print("b xla :", b_xla)
print("c golden :", c_golden)
print("c xla :", c_xla)
torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
torch.testing.assert_close(c_golden, c_xla.cpu(), rtol=0, atol=0)
Save as test_split.py and run:
PJRT_DEVICE=CPU python test_split.py
WARNING:torch_neuron:RANK environment variable is not set, defaulting to 0.
WARNING:torch_neuron:LOCAL RANK environment variable is not set to 0, defaulting to 0.
a original: tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
b golden : tensor([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]])
b xla : tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]], device='xla:0')
c golden : tensor([[ 3],
[ 7],
[11]])
c xla : tensor([[3],
[4],
[5]], device='xla:0')
Traceback (most recent call last):
File "/home/ubuntu/transformers/examples/pytorch/text-classification/test_split.py", line 15, in <module>
torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
File "/home/ubuntu/aws_neuron_venv_pt26/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!
Mismatched elements: 6 / 9 (66.7%)
Greatest absolute difference: 2 at index (2, 0)
Greatest relative difference: 0.3333333432674408 at index (1, 0)
Expected behavior
The XLA:CPU results should match the CPU results.
Environment
- Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
- torch_xla version: 2.6 (also 2.1)
Additional context
Strangely, this issue is not reproducible on GPU.
It turns out that disabling functionalization (XLA_DISABLE_FUNCTIONALIZATION=1) causes the error, and XLA_DISABLE_FUNCTIONALIZATION=1 is the default in the Neuron environment. To work around this, set XLA_DISABLE_FUNCTIONALIZATION=0:
(cpu_venv_py310) ubuntu@ip-10-3-190-82:~$ XLA_DISABLE_FUNCTIONALIZATION=0 PJRT_DEVICE=CPU python test_split.py
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
a original: tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
b golden : tensor([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]])
b xla : tensor([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]], device='xla:0')
c golden : tensor([[ 3],
[ 7],
[11]])
c xla : tensor([[ 3],
[ 7],
[11]], device='xla:0')
(cpu_venv_py310) ubuntu@ip-10-3-190-82:~$ XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CPU python test_split.py
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
a original: tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
b golden : tensor([[ 0, 1, 2],
[ 4, 5, 6],
[ 8, 9, 10]])
b xla : tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]], device='xla:0')
c golden : tensor([[ 3],
[ 7],
[11]])
c xla : tensor([[3],
[4],
[5]], device='xla:0')
Traceback (most recent call last):
File "/home/ubuntu/test_split.py", line 16, in <module>
torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
File "/home/ubuntu/cpu_venv_py310/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!
Mismatched elements: 6 / 9 (66.7%)
Greatest absolute difference: 2 at index (2, 0)
Greatest relative difference: 0.3333333432674408 at index (1, 0)
Reopening since XLA_DISABLE_FUNCTIONALIZATION=1 is still used by Neuron.
Thank you for submitting this issue. I was able to reproduce it on commit 225c65bd7b00ca5162a9979dac3b118e3f00fbf7. I will look into this.
@ysiraichi do we have an update on this bug?
cc @amjames
Sorry. I still haven't had the time to look into this issue.
As noted in https://github.com/aws-neuron/aws-neuron-sdk/issues/1140, using tensor_split works around this issue.
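For reference, a minimal sketch of that workaround on the repro above. `tensor_split` takes boundary indices rather than chunk sizes, so `split([3, 1], dim=-1)` becomes `tensor_split([3], dim=-1)`:

```python
import torch
import torch_xla

a_xla = torch.arange(12, device="xla").reshape(3, 4)
# One boundary at column 3 -> chunks of width 3 and 1, same as split([3, 1]).
b_xla, c_xla = a_xla.tensor_split([3], dim=-1)
```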
@ysiraichi will you be able to look at this issue sometime for v2.9?
Ah, sorry. I ended up never going back to it. I will take a look at it this week.
The primary problem appears to be that split_with_sizes relies on as_strided, but our current lowering implementation for as_strided is too limited. Specifically, it doesn't correctly handle as_strided arguments that describe a non-contiguous view.
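To make that concrete, here is a minimal sketch (eager PyTorch, purely for illustration) of the as_strided views that `split([3, 1], dim=-1)` produces. Both chunks keep the parent's row stride, so neither view is contiguous, and the second chunk also needs a non-zero storage offset:

```python
import torch

a = torch.arange(12).reshape(3, 4)    # strides: (4, 1)
b = a.as_strided((3, 3), (4, 1), 0)   # first 3 columns; non-contiguous
c = a.as_strided((3, 1), (4, 1), 3)   # last column; non-zero storage offset
assert torch.equal(b, a[:, :3]) and torch.equal(c, a[:, 3:])
```

The incorrect XLA output above is consistent with the stride being ignored: reading 9 and 3 contiguous elements from storage offsets 0 and 3 yields exactly the buggy `b` and `c`. Here are a few solutions I can think of: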
1. Use `split_with_sizes_copy()` instead: call the `copy()` variant, e.g. `torch.split_with_sizes_copy(a, [3, 1], dim=-1)`.
2. Raise an error: easy to do, but doesn't fix the error.
3. Implement a better lowering for `as_strided`: non-trivial. A similar approach was previously implemented for the functionalization path, which might provide a useful reference.
4. Re-write `as_strided` tracing without functionalization in terms of `take`: it would require adapting our view machinery if not implemented directly within it.
5. Change `split_with_sizes` tracing: instead of using its `CompositeExplicitAutograd` kernel, we could create a custom kernel. While the lowering itself might not be overly difficult, correctly handling the view part would be challenging.
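For reference, a minimal sketch of option (1) applied to the repro above, assuming `torch.split_with_sizes_copy` lowers correctly on XLA; because it materializes copies rather than returning strided views, the problematic `as_strided` path is never taken:

```python
import torch
import torch_xla

a_xla = torch.arange(12, device="xla").reshape(3, 4)
# Copy variant: materializes each chunk instead of returning as_strided views.
b_xla, c_xla = torch.split_with_sizes_copy(a_xla, [3, 1], dim=-1)
```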
If (1) is enough to solve this issue, I think that's the easiest fix. That said, even if (1) is sufficient, (2) should be implemented so as to stop returning an incorrect result.
If a proper fix is necessary, I think we should go for (5) out of the remaining options: its lowering is straightforward, and arguably better than implementing a proper as_strided lowering.
If we end up going for (3) or (4), we should probably move the index computation in the as_strided tracing to AsStrided::Lower, which would benefit both the functionalization and non-functionalization paths.
@jeffhataws Let me know if (1) is enough, or if you need tensor.split() to work, i.e. fix (3), (4), or (5).
Hi @ysiraichi, sorry for not following up sooner. Let's try to add (2) if possible and point users to (1).