Extracted subarray's device is 'lazy' instead of 'xla' when using ellipsis extraction with XLA_DISABLE_FUNCTIONALIZATION=1
🐛 Bug
We use XLA_DISABLE_FUNCTIONALIZATION=1 in torch-xla 2.1 to work around the trace slowdown issue (https://github.com/pytorch/xla/issues/6294). However, we are encountering a strange issue with the reproduction code in the next section.
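For reference, the flag can also be set from Python, assuming it is set before the first torch_xla import (a minimal sketch, not verified against the torch-xla 2.1 internals):

import os

# Assumption: XLA_DISABLE_FUNCTIONALIZATION is read when torch_xla initializes,
# so it must be set before torch_xla is imported for it to take effect.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()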
The code registers a buffer with register_buffer inside a module. The forward method simply extracts a subset of the buffer. To provide flexibility and support more dimensions, we use an ellipsis in the extraction.
We noted that the use of the ellipsis somehow causes the extracted buffer to have device 'lazy' instead of 'xla' when the extraction length equals the buffer size.
To Reproduce
Run the following code (save it as register_buffer.py) with XLA_DISABLE_FUNCTIONALIZATION=1:
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

class TestRegisterBuffCls(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.register_buffer('buffer', torch.arange(size, dtype=torch.float), persistent=False)

    def forward(self, length):
        return (self.buffer[:length, ...],
                self.buffer[:length])

buffer_size = 32
mod = TestRegisterBuffCls(buffer_size).to(device)
for extract_length in range(0, 64, 8):
    with_ellipsis, without_ellipsis = mod(extract_length)
    print("buffer_size: ", buffer_size, "extract_length: ", extract_length)
    print(with_ellipsis.shape, with_ellipsis.device)
    print(without_ellipsis.shape, without_ellipsis.device)
    assert(with_ellipsis.shape == without_ellipsis.shape), "Shapes don't match"
    assert(with_ellipsis.device == without_ellipsis.device), "Devices don't match"
XLA_DISABLE_FUNCTIONALIZATION=1 python register_buffer.py
buffer_size: 32 extract_length: 0
torch.Size([0]) xla:0
torch.Size([0]) xla:0
buffer_size: 32 extract_length: 8
torch.Size([8]) xla:0
torch.Size([8]) xla:0
buffer_size: 32 extract_length: 16
torch.Size([16]) xla:0
torch.Size([16]) xla:0
buffer_size: 32 extract_length: 24
torch.Size([24]) xla:0
torch.Size([24]) xla:0
buffer_size: 32 extract_length: 32
torch.Size([32]) lazy:0
torch.Size([32]) xla:0
Traceback (most recent call last):
File "register_buffer.py", line 23, in <module>
assert(with_ellipsis.device == without_ellipsis.device), "Devices don't match"
AssertionError: Devices don't match
As a sanity check, run with XLA_DISABLE_FUNCTIONALIZATION=0:
XLA_DISABLE_FUNCTIONALIZATION=0 python register_buffer.py
buffer_size: 32 extract_length: 0
torch.Size([0]) xla:0
torch.Size([0]) xla:0
buffer_size: 32 extract_length: 8
torch.Size([8]) xla:0
torch.Size([8]) xla:0
buffer_size: 32 extract_length: 16
torch.Size([16]) xla:0
torch.Size([16]) xla:0
buffer_size: 32 extract_length: 24
torch.Size([24]) xla:0
torch.Size([24]) xla:0
buffer_size: 32 extract_length: 32
torch.Size([32]) xla:0
torch.Size([32]) xla:0
buffer_size: 32 extract_length: 40
torch.Size([32]) xla:0
torch.Size([32]) xla:0
buffer_size: 32 extract_length: 48
torch.Size([32]) xla:0
torch.Size([32]) xla:0
buffer_size: 32 extract_length: 56
torch.Size([32]) xla:0
torch.Size([32]) xla:0
Expected behavior
The device of the extracted subarray should be "xla:0", not "lazy:0", regardless of whether XLA_DISABLE_FUNCTIONALIZATION is set.
Environment
- Reproducible on XLA backend [CPU/TPU]: CPU/Neuron
- torch_xla version: 2.1
Additional context
Replacing the ellipsis with ":" helps, but makes the code less general (a dimension-agnostic alternative is sketched after the snippet below).
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

class TestRegisterBuffCls(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.register_buffer('buffer', torch.zeros((size, 100), dtype=torch.float), persistent=False)

    def forward(self, length):
        #return (self.buffer[:length, ...],
        #        self.buffer[:length])
        return (self.buffer[:length, :],
                self.buffer[:length])

buffer_size = 32
mod = TestRegisterBuffCls(buffer_size).to(device)
for extract_length in range(0, 64, 8):
    with_ellipsis, without_ellipsis = mod(extract_length)
    print("buffer_size: ", buffer_size, "extract_length: ", extract_length)
    print(with_ellipsis.shape, with_ellipsis.device)
    print(without_ellipsis.shape, without_ellipsis.device)
    assert(with_ellipsis.shape == without_ellipsis.shape), "Shapes don't match"
    assert(with_ellipsis.device == without_ellipsis.device), "Devices don't match"
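If keeping the forward method dimension-agnostic matters, one option is to build the index tuple explicitly instead of writing an ellipsis. The class name below is illustrative only, and this is just a sketch; whether it also avoids the 'lazy' device under XLA_DISABLE_FUNCTIONALIZATION=1 has not been verified:

import torch

class TestRegisterBuffClsNoEllipsis(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.register_buffer('buffer', torch.zeros((size, 100), dtype=torch.float), persistent=False)

    def forward(self, length):
        # Equivalent to self.buffer[:length, ...], spelled without Ellipsis:
        # slice the first dim and take every remaining dim in full.
        index = (slice(None, length),) + (slice(None),) * (self.buffer.dim() - 1)
        return self.buffer[index]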
Hmm, this seems to be a bug in the functionalization pass + LTC. My best guess is that we call some upstream helper which defaults the device to lazy. @wonjoolee95 can you take this one?
Apologies for the late reply, and thanks for bringing this to our attention, @jeffhataws. I'll take a look at this sometime next week. We'll make sure this gets fixed by the 2.3 release at the latest.
Hi @wonjoolee95, I just want to follow up on this issue to see if there's a fix.