
split on second dimension of 2D array not working with XLA_DISABLE_FUNCTIONALIZATION=1

Open jeffhataws opened this issue 11 months ago • 11 comments

🐛 Bug

When running a small example that splits a 2D array along the second dimension, the resulting tensors don't contain the expected data. The results differ between CPU and XLA:CPU.

To Reproduce

Run the following test:

import torch
import torch_xla

a_golden = torch.arange(12, device="cpu").reshape(3, 4)
b_golden, c_golden = a_golden.split([3, 1], dim=-1)
a_xla = torch.arange(12, device="xla").reshape(3, 4)
b_xla, c_xla = a_xla.split([3, 1], dim=-1)

print("a original:", a_golden)
print("b golden :", b_golden)
print("b xla :", b_xla)
print("c golden :", c_golden)
print("c xla :", c_xla)

torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
torch.testing.assert_close(c_golden, c_xla.cpu(), rtol=0, atol=0)

Save as test_split.py and run:

PJRT_DEVICE=CPU python test_split.py
WARNING:torch_neuron:RANK environment variable is not set, defaulting to 0.
WARNING:torch_neuron:LOCAL RANK environment variable is not set to 0, defaulting to 0.
a original: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
b golden : tensor([[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]])
b xla : tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]], device='xla:0')
c golden : tensor([[ 3],
        [ 7],
        [11]])
c xla : tensor([[3],
        [4],
        [5]], device='xla:0')
Traceback (most recent call last):
  File "/home/ubuntu/transformers/examples/pytorch/text-classification/test_split.py", line 15, in <module>
    torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
  File "/home/ubuntu/aws_neuron_venv_pt26/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 6 / 9 (66.7%)
Greatest absolute difference: 2 at index (2, 0)
Greatest relative difference: 0.3333333432674408 at index (1, 0)

Expected behavior

The XLA:CPU results should match the CPU results.

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CPU
  • torch_xla version: 2.6 (also 2.1)

Additional context

jeffhataws · Jan 28 '25

Strangely, this issue is not reproducible on GPU, which is confusing.

jeffhataws · Jan 29 '25

It turns out that disabling functionalization (XLA_DISABLE_FUNCTIONALIZATION=1) causes the error. XLA_DISABLE_FUNCTIONALIZATION=1 is the default in the Neuron environment. To resolve this, set XLA_DISABLE_FUNCTIONALIZATION=0.

(cpu_venv_py310) ubuntu@ip-10-3-190-82:~$ XLA_DISABLE_FUNCTIONALIZATION=0 PJRT_DEVICE=CPU python test_split.py
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
a original: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
b golden : tensor([[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]])
b xla : tensor([[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]], device='xla:0')
c golden : tensor([[ 3],
        [ 7],
        [11]])
c xla : tensor([[ 3],
        [ 7],
        [11]], device='xla:0')
(cpu_venv_py310) ubuntu@ip-10-3-190-82:~$ XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CPU python test_split.py 
WARNING:root:MASTER_ADDR environment variable is not set, defaulting to localhost
a original: tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
b golden : tensor([[ 0,  1,  2],
        [ 4,  5,  6],
        [ 8,  9, 10]])
b xla : tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]], device='xla:0')
c golden : tensor([[ 3],
        [ 7],
        [11]])
c xla : tensor([[3],
        [4],
        [5]], device='xla:0')
Traceback (most recent call last):
  File "/home/ubuntu/test_split.py", line 16, in <module>
    torch.testing.assert_close(b_golden, b_xla.cpu(), rtol=0, atol=0)
  File "/home/ubuntu/cpu_venv_py310/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 6 / 9 (66.7%)
Greatest absolute difference: 2 at index (2, 0)
Greatest relative difference: 0.3333333432674408 at index (1, 0)

jeffhataws · Feb 03 '25

Reopening since XLA_DISABLE_FUNCTIONALIZATION=1 is still used by Neuron.

jeffhataws · Feb 03 '25

Thank you for submitting this issue. I was able to reproduce it on 225c65bd7b00ca5162a9979dac3b118e3f00fbf7. I will look into this.

ysiraichi · Feb 05 '25

@ysiraichi do we have an update on this bug?

cc @amjames

miladm · Mar 13 '25

Sorry. I still haven't had the time to look into this issue.

ysiraichi · Mar 14 '25

As noted in https://github.com/aws-neuron/aws-neuron-sdk/issues/1140, using tensor_split works around this issue.
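
For reference, a minimal sketch of that workaround applied to the repro above. Note that tensor_split takes split indices rather than chunk sizes, so [3] produces the same chunks as split([3, 1], dim=-1):

import torch
import torch_xla

a_xla = torch.arange(12, device="xla").reshape(3, 4)
# tensor_split takes split *indices*: [3] yields columns [0:3] and [3:4],
# the same chunks that split([3, 1], dim=-1) should produce.
b_xla, c_xla = a_xla.tensor_split([3], dim=-1)
print(b_xla.cpu())  # expected: [[0, 1, 2], [4, 5, 6], [8, 9, 10]]
print(c_xla.cpu())  # expected: [[3], [7], [11]]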

jeffhataws · Apr 30 '25

@ysiraichi will you be able to look at this issue sometime for v2.9?

jeffhataws · Jul 25 '25

Ah, sorry. I ended up never going back to it. I will take a look at it this week.

ysiraichi · Jul 28 '25

The primary problem appears to be that split_with_sizes relies on as_strided, but our current lowering implementation for as_strided is too limited. Specifically, it doesn't correctly handle as_strided arguments for a non-contiguous view. Here are a few solutions I can think of:

  1. Use split_with_sizes_copy() instead: call the copy() variant, e.g. torch.split_with_sizes_copy(a, [3, 1], dim=-1) (see the sketch after this list).
  2. Raise an error: easy to do, but doesn't fix the underlying bug.
  3. Implement a better lowering for as_strided: Non-trivial. A similar approach was previously implemented for the functionalization path, which might provide a useful reference.
  4. Re-write the as_strided tracing (without functionalization) in terms of take: this would require adapting our view machinery, if not implementing it directly within it.
  5. Change split_with_sizes tracing: Instead of using its CompositeExplicitAutograd kernel, we could create a custom kernel. While the lowering itself might not be overly difficult, correctly handling the view part would be challenging.
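
For illustration, a minimal sketch of option (1) applied to the repro, assuming the same [3, 1] column split. Per the diagnosis above, the _copy variant returns materialized tensors rather than as_strided views:

import torch
import torch_xla

a_xla = torch.arange(12, device="xla").reshape(3, 4)
# The *_copy variant materializes each chunk instead of returning strided
# views, so it does not go through the limited as_strided lowering.
b_xla, c_xla = torch.split_with_sizes_copy(a_xla, [3, 1], dim=-1)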

If (1) is enough to solve this issue, I think that's the easiest fix. That said, even if (1) is sufficient, (2) should be implemented so as to stop returning an incorrect result.

If a proper fix is necessary, out of the remaining options, I think we should go for (5): its lowering is straightforward, and arguably better than implementing a proper as_strided lowering.

If we end up going for (3) or (4), we should probably move the index computation in the as_strided tracing to AsStrided::Lower, which would benefit both the functionalization and non-functionalization paths.
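
As a rough illustration of that index computation, here is a hypothetical eager-mode sketch of the take-based approach from option (4). as_strided_via_take is an invented name, and this is plain PyTorch, not the actual torch_xla lowering:

import torch

def as_strided_via_take(storage, size, stride, storage_offset=0):
    # storage: a 1D tensor viewed over the base storage.
    # Build the flat storage index of every output element from
    # (size, stride, storage_offset), then gather with take().
    index = torch.full(size, storage_offset, dtype=torch.long)
    for dim, (n, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[dim] = n
        index = index + (torch.arange(n) * st).reshape(shape)
    return storage.take(index)

# The second chunk of split([3, 1], dim=-1) on a 3x4 tensor is the view
# with size (3, 1), stride (4, 1), storage offset 3:
a = torch.arange(12)
c = as_strided_via_take(a, (3, 1), (4, 1), storage_offset=3)
# c == tensor([[ 3], [ 7], [11]])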

@jeffhataws Let me know if (1) is enough, or if you need tensor.split() to work, i.e. fix (3), (4), or (5).

ysiraichi · Jul 29 '25

Hi @ysiraichi, sorry for not following up sooner. Let's try adding (2) if possible and point users to (1).

jeffhataws · Oct 07 '25