
Extracted subarray's device is 'lazy' instead of 'xla' when using ellipsis extraction with XLA_DISABLE_FUNCTIONALIZATION=1

Open jeffhataws opened this issue 1 year ago • 3 comments

🐛 Bug

We use XLA_DISABLE_FUNCTIONALIZATION=1 in torch-xla 2.1 to work around the trace slowdown issue (https://github.com/pytorch/xla/issues/6294). However, with that setting we are encountering a strange issue, shown by the reproduction code in the next section.

The code registers a buffer inside a module using register_buffer. The forward method simply extracts a subset of the buffer. To provide flexibility and support more dimensions, we use an ellipsis in the extraction.

We noticed that using the ellipsis somehow causes the extracted buffer to report device 'lazy' instead of 'xla' when the extraction length equals the buffer size.
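For readers less familiar with how the two indexing forms differ, here is a minimal pure-Python sketch of what Python passes to `__getitem__` in each case. `IndexRecorder` is a hypothetical helper used only for illustration; it does not involve torch or torch_xla and does not reproduce the bug. It only shows that the two forms arrive as different key objects, and that a slice up to the buffer size selects the whole dimension.

```python
class IndexRecorder:
    """Hypothetical helper: returns whatever key Python passes to __getitem__."""

    def __getitem__(self, key):
        return key


rec = IndexRecorder()
length, size = 32, 32

# With the ellipsis, the key is a tuple holding a slice and Ellipsis.
with_ellipsis = rec[:length, ...]
# Without it, the key is a bare slice.
without_ellipsis = rec[:length]

# slice.indices(size) resolves a slice against a dimension of that size.
start, stop, step = without_ellipsis.indices(size)
covers_whole_dim = (start, stop, step) == (0, size, 1)

print(with_ellipsis)
print(without_ellipsis)
print(covers_whole_dim)
```

Whether this difference in key shape is related to the device mismatch is not established in this report.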

To Reproduce

Run the following code (save it as register_buffer.py) with XLA_DISABLE_FUNCTIONALIZATION=1:

import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()

class TestRegisterBuffCls(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.register_buffer('buffer', torch.arange(size, dtype=torch.float), persistent=False)

    def forward(self, length):
        return (self.buffer[:length, ...],
                self.buffer[:length])

buffer_size = 32
mod = TestRegisterBuffCls(buffer_size).to(device)

for extract_length in range(0, 64, 8):
    with_ellipsis, without_ellipsis = mod(extract_length)
    print("buffer_size: ", buffer_size, "extract_length: ", extract_length)
    print(with_ellipsis.shape, with_ellipsis.device)
    print(without_ellipsis.shape, without_ellipsis.device)
    assert(with_ellipsis.shape == without_ellipsis.shape), "Shapes don't match"
    assert(with_ellipsis.device == without_ellipsis.device), "Devices don't match"

XLA_DISABLE_FUNCTIONALIZATION=1 python register_buffer.py

buffer_size:  32 extract_length:  0
torch.Size([0]) xla:0
torch.Size([0]) xla:0
buffer_size:  32 extract_length:  8
torch.Size([8]) xla:0
torch.Size([8]) xla:0
buffer_size:  32 extract_length:  16
torch.Size([16]) xla:0
torch.Size([16]) xla:0
buffer_size:  32 extract_length:  24
torch.Size([24]) xla:0
torch.Size([24]) xla:0
buffer_size:  32 extract_length:  32
torch.Size([32]) lazy:0
torch.Size([32]) xla:0
Traceback (most recent call last):
  File "register_buffer.py", line 23, in <module>
    assert(with_ellipsis.device == without_ellipsis.device), "Devices don't match" 
AssertionError: Devices don't match

As a sanity check, run with XLA_DISABLE_FUNCTIONALIZATION=0:

XLA_DISABLE_FUNCTIONALIZATION=0 python register_buffer.py

buffer_size:  32 extract_length:  0
torch.Size([0]) xla:0
torch.Size([0]) xla:0
buffer_size:  32 extract_length:  8
torch.Size([8]) xla:0
torch.Size([8]) xla:0
buffer_size:  32 extract_length:  16
torch.Size([16]) xla:0
torch.Size([16]) xla:0
buffer_size:  32 extract_length:  24
torch.Size([24]) xla:0
torch.Size([24]) xla:0
buffer_size:  32 extract_length:  32
torch.Size([32]) xla:0
torch.Size([32]) xla:0
buffer_size:  32 extract_length:  40
torch.Size([32]) xla:0
torch.Size([32]) xla:0
buffer_size:  32 extract_length:  48
torch.Size([32]) xla:0
torch.Size([32]) xla:0
buffer_size:  32 extract_length:  56
torch.Size([32]) xla:0
torch.Size([32]) xla:0

Expected behavior

Device of the extracted array should be "xla:0", not "lazy:0", whether XLA_DISABLE_FUNCTIONALIZATION is set or not.

Environment

  • Reproducible on XLA backend [CPU/TPU]: CPU/Neuron
  • torch_xla version: 2.1

Additional context

jeffhataws avatar Jan 29 '24 07:01 jeffhataws

Replacing the ellipsis with ":" helps, but makes the code less general.

import torch
import torch_xla.core.xla_model as xm
device = xm.xla_device()

class TestRegisterBuffCls(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.register_buffer('buffer', torch.zeros((size,100), dtype=torch.float), persistent=False)

    def forward(self, length):
        #return (self.buffer[:length, ...],
        #        self.buffer[:length])
        return (self.buffer[:length, :],
                self.buffer[:length])

buffer_size = 32
mod = TestRegisterBuffCls(buffer_size).to(device)

for extract_length in range(0, 64, 8):
    with_ellipsis, without_ellipsis = mod(extract_length)
    print("buffer_size: ", buffer_size, "extract_length: ", extract_length)
    print(with_ellipsis.shape, with_ellipsis.device)
    print(without_ellipsis.shape, without_ellipsis.device)
    assert(with_ellipsis.shape == without_ellipsis.shape), "Shapes don't match"
    assert(with_ellipsis.device == without_ellipsis.device), "Devices don't match"
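One possible way to keep the generality without writing an ellipsis is to build the index from explicit slices based on the number of dimensions. The sketch below is an assumption, not something verified in this thread: `leading_slice_index` is a hypothetical helper, and only the hand-written `[:length, :]` form on a 2-D buffer is reported above to help. It has not been confirmed that a programmatically built tuple of slices avoids the 'lazy' device.

```python
def leading_slice_index(ndim, length):
    """Hypothetical helper: index equivalent to [:length, ...] using explicit slices only."""
    if ndim < 1:
        raise ValueError("need at least one dimension")
    # Slice the first dimension, take every remaining dimension in full.
    return (slice(None, length),) + (slice(None),) * (ndim - 1)


# Hypothetical use inside forward():
#     return self.buffer[leading_slice_index(self.buffer.dim(), length)]
print(leading_slice_index(2, 8))
```

For a 2-D buffer this produces the same key as writing `[:length, :]` by hand, so the module would not need to hard-code the number of dimensions.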

jeffhataws avatar Jan 29 '24 17:01 jeffhataws

Hmm, this seems to be a bug in the functionalization pass + LTC. My best guess is that we call some upstream helper which defaults the device to lazy. @wonjoolee95 can you take this one?

JackCaoG avatar Jan 29 '24 21:01 JackCaoG

Apologies for the late reply, and thanks for bringing this to our attention, @jeffhataws. I'll take a look at this sometime next week. We'll make sure this gets fixed by the 2.3 release at the latest.

wonjoolee95 avatar Feb 22 '24 06:02 wonjoolee95

Hi @wonjoolee95, I just want to follow up on this issue to see if there's a fix.

jeffhataws avatar May 08 '24 05:05 jeffhataws