
torch.linalg.lstsq issues on GPU/TPU

Open ttdd11 opened this issue 8 months ago • 33 comments

When using the CPU as the device, torch.linalg.lstsq works as expected. When using GPU/TPU, it fails with a shape-mismatch error.

To reproduce:

import torch

import torch_xla as xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device) 
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = A*B
    F = C*D
    ref = torch.linalg.lstsq(A*B,C*D).solution
    loss = torch.mean(ref - diff)
    loss.backward()

With the error:

E0408 16:13:38.043340 1246240 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: INVALID_ARGUMENT: Executable expected shape f32[3]{0} for argument 0 but got incompatible shape f32[8,3]{1,0}
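For comparison, a CPU-only sketch (plain PyTorch, no torch_xla) of what the batched call should produce. torch.linalg.lstsq treats the leading dimension of 8 as a batch of independent 251x3 problems, so the solution has shape (8, 3, 1) and the gradients keep the full batched shapes. The log above reports that the executable expected f32[3] but received f32[8,3], which suggests (not confirmed here) that the batch dimension is being dropped somewhere in the compiled graph.

```python
import torch

torch.manual_seed(0)

# Same shapes as the repro above, but on the CPU.
A = torch.randn(8, 251, 3, requires_grad=True)
B = torch.randn(8, 251, 1, requires_grad=True)
C = torch.randn(8, 251, 1, requires_grad=True)
D = torch.randn(8, 251, 1, requires_grad=True)
diff = torch.randn(8, 3, 1, requires_grad=True)

# A*B broadcasts (8, 251, 3) * (8, 251, 1) -> (8, 251, 3); C*D is (8, 251, 1).
# lstsq then solves 8 independent 251x3 systems.
ref = torch.linalg.lstsq(A * B, C * D).solution
print(ref.shape)  # torch.Size([8, 3, 1])

loss = torch.mean(ref - diff)
loss.backward()
print(A.grad.shape)  # torch.Size([8, 251, 3])
```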

Tested using torch_xla 2.4 on GPU and 2.6 on TPU. If the tensors are not moved to the device, this runs as expected.

Any help with this would be appreciated.

ttdd11, Apr 08 '25

I'm assuming that pulling this down to the CPU and then pushing the results back to the device would cause slowdowns that make TPU usage a bit tricky. Any assistance here would be greatly appreciated.
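A minimal sketch of that CPU round trip, using a hypothetical helper name `lstsq_via_cpu`: the inputs are copied to the CPU, solved there, and the solution is moved back to the caller's device. Autograd records `.to()` as a differentiable op, so gradients flow back to the original tensors. Under torch_xla, the `.to("cpu")` copy would force the pending graph to execute on every call, which is the source of the slowdown concern; the actual cost on TPU is untested here. The example itself only runs on CPU tensors.

```python
import torch

def lstsq_via_cpu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: solve a batched least-squares problem on the CPU.

    `.to()` is differentiable, so gradients flow back to `a` and `b`
    on whatever device they came from.
    """
    solution = torch.linalg.lstsq(a.to("cpu"), b.to("cpu")).solution
    return solution.to(a.device)


# Usage with CPU tensors; on an XLA device, pass tensors created on xm.xla_device().
A = torch.randn(8, 251, 3, requires_grad=True)
b = torch.randn(8, 251, 1, requires_grad=True)
x = lstsq_via_cpu(A, b)
x.sum().backward()
print(x.shape, A.grad.shape)  # torch.Size([8, 3, 1]) torch.Size([8, 251, 3])
```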

ttdd11, Apr 10 '25

Thanks for sharing this issue. Can you share the line that triggers the exception? I assume it's the call to lstsq.

Can you also run it with E and F directly as

ref = torch.linalg.lstsq(E, F).solution

Can you materialize tensors with the shapes of E and F directly, like this, and try again:

ref = torch.linalg.lstsq(torch.randn_like(E), torch.randn_like(F)).solution

And finally, to break the graph into separate compilations, can you do this:

G = torch.randn_like(E)
H = torch.randn_like(F)
xm.mark_step()
ref = torch.linalg.lstsq(G, H).solution

thanks so much

yaoshiang, Apr 11 '25

@yaoshiang thanks for the reply. Here are the results:

  1. Same error as above
  2. No crash
  3. No crash
  4. I also tried this (but that also couldn't run the backward without crashing):
import torch
import torch_xla as xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device) 
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = torch.mul(A,B)
    F = torch.mul(C,D)
    G = E.clone()
    H = F.clone()
    xm.mark_step()
    ref = torch.linalg.lstsq(G, H).solution
    #ref = torch.linalg.lstsq(torch.randn_like(E), torch.randn_like(F)).solution
    # ref = torch.cat([ref[:, :2], ref[:, 2:]], dim=1)
    #ref = torch.cat([ref[:, :2]*scale_rel_backproj, ref[:, 2:] * (scale_rel_backproj / scale2d)], dim=1)
    #ref =  torch.squeeze(ref, dim=-1)
    loss = torch.mean(ref - diff)
    loss.backward()

So it looks like lstsq can handle the batch dimension, provided the inputs aren't multiplied beforehand.

Any advice on how to manage this with the multiplications?

Thanks again.

ttdd11, Apr 12 '25

Thanks so much. It looks like an issue with the longer series of ops during compilation. Can you try this as a potential workaround:

with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device)
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = A*B
    F = C*D
    xm.mark_step()  # no clones
    ref = torch.linalg.lstsq(A*B, C*D).solution
    loss = torch.mean(ref - diff)
    loss.backward()

yaoshiang, Apr 12 '25

@yaoshiang thanks for the reply. Here is the code that I tried:

import torch
import torch_xla as xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device) 
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    xm.mark_step()
    ref = torch.linalg.lstsq(A*B, C*D).solution
    loss = torch.mean(ref - diff)
    loss.backward()

which unfortunately yields the same result: the crash occurs in the backward pass with the same shape-mismatch error.

ttdd11, Apr 13 '25

Hi,

You can try this in torchax (https://github.com/pytorch/xla/tree/master/torchax):

import torch
import torchax
torchax.enable_globally()

device = 'jax'


diff = torch.randn(8,3,1,requires_grad=True,device=device) 
A = torch.randn(8,251,3,requires_grad=True,device=device)
B = torch.randn(8,251,1,requires_grad=True,device=device)
C = torch.randn(8,251,1,requires_grad=True,device=device)
D = torch.randn(8,251,1,requires_grad=True,device=device)
#TODO(https://github.com/pytorch/xla/issues/8983)
A.requires_grad = True
B.requires_grad = True
C.requires_grad = True
D.requires_grad = True
ref = torch.linalg.lstsq(A*B, C*D).solution
loss = torch.mean(ref - diff)
loss.backward()
print(A.grad)

An even better approach (this one doesn't require setting requires_grad explicitly):

def func(A, B, C, D):
  ref = torch.linalg.lstsq(A*B, C*D).solution
  loss = torch.mean(ref - diff)
  return loss

grad_fn = torchax.interop.jax_value_and_grad(func)

loss, gradients = grad_fn(A, B, C, D)
print(gradients[0])

qihqi, Apr 16 '25

@qihqi thank you again! I can give this a try today. Other than the install instructions you provided in that link, are there any other differences to using jax as the device?

ttdd11, Apr 16 '25

I wasn't able to get the same error on GPU.

Traceback (most recent call last):
  File "8953.py", line 16, in <module>
    ref = torch.linalg.lstsq(A*B,C*D).solution
RuntimeError: Trying to resize storage that is not resizable

PyTorch/XLA: 9ec626e653336f47221636839c9d31110e5f0a80

ysiraichi, Apr 16 '25

I ran it on pt-cuda and it worked successfully.

yaoshiang, Apr 16 '25

@yaoshiang what version of torch are you using? I tested this on 2.4 GPU and 2.6 TPU with the same error. Also - what variant of the code are you running?

Thanks for the help.

ttdd11, Apr 16 '25

> @qihqi thank you again! I can give this a try today. Other than the install instructions you provided in that link, are there any other differences to using jax as the device?

Yes: the speed you get is JAX eager speed. If you want faster execution through compilation, you need to wrap the function with torchax.interop.jax_jit, which is equivalent to jax.jit.

More details here: https://docs.jax.dev/en/latest/faq.html#is-jax-faster-than-numpy
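For illustration only, here is the same loss written directly in JAX rather than through torchax, assuming the compiled torchax path behaves similarly: `jax.value_and_grad` plays the role of `torchax.interop.jax_value_and_grad`, and `jax.jit` compiles the whole function once. `jnp.linalg.lstsq` solves a single 2-D system, so the batch of 8 is mapped with `jax.vmap`. This is a sketch, not the torchax API.

```python
import jax
import jax.numpy as jnp

def solve_one(a, b):
    # jnp.linalg.lstsq returns (solution, residuals, rank, singular_values)
    return jnp.linalg.lstsq(a, b)[0]

# Map over the leading batch dimension: (8, 251, 3), (8, 251, 1) -> (8, 3, 1)
batched_lstsq = jax.vmap(solve_one)

def loss_fn(A, B, C, D, diff):
    ref = batched_lstsq(A * B, C * D)
    return jnp.mean(ref - diff)

# Compile value-and-gradient with respect to A, B, C, D (argnums 0..3).
step = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1, 2, 3)))

keys = jax.random.split(jax.random.PRNGKey(0), 5)
A = jax.random.normal(keys[0], (8, 251, 3))
B = jax.random.normal(keys[1], (8, 251, 1))
C = jax.random.normal(keys[2], (8, 251, 1))
D = jax.random.normal(keys[3], (8, 251, 1))
diff = jax.random.normal(keys[4], (8, 3, 1))

loss, grads = step(A, B, C, D, diff)
print(loss, grads[0].shape)
```

Calling `step` again with inputs of the same shapes reuses the compiled executable; a new shape triggers a recompile.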

If you want to train a model, the recommendation would be to use this function https://github.com/pytorch/xla/blob/master/torchax/torchax/train.py#L15 to generate a train step and call it repeatedly, instead of the regular loss.backward(); optimizer.step() incantation.

qihqi, Apr 16 '25

@qihqi I'm currently running this eagerly anyway. I think the nature of a least-squares optimization requires a lot of materialization of intermediate values, so I'm not sure getting away from eager is possible. This function works as-is in a for loop over the batch, and without running eager it's 10x slower.

Regarding the training of the model: if I use the code above - you are recommending that I adjust the entire train paradigm we are using to use that in jax or is that only if we want graph compilation execution instead of eager?

ttdd11, Apr 16 '25

> @yaoshiang what version of torch are you using? I tested this on 2.4 GPU and 2.6 TPU with the same error. Also - what variant of the code are you running?
>
> Thanks for the help.

Sorry for the confusion - I ran it on an Nvidia T4 with the latest torch, no torch_xla, and it ran. If you got different results, let me know.

yaoshiang, Apr 16 '25

> Regarding the training of the model: if I use the code above - you are recommending that I adjust the entire train paradigm we are using to use that in jax or is that only if we want graph compilation execution instead of eager?

Only if you use graph compilation.

To be clear, when using torchax, both eager mode and compiled mode use JAX to execute under the hood. With graph compilation, you want to compile the thing that you will launch repeatedly without changing input shapes, and the "train_step" concept fits this well: a train step is one forward + one backward + one optimizer update (NOTE: if you use torch.compile on GPU, this is also the recommendation: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial_.html).

With torchax, you can do both torch style train loop OR jax style train loop. The jax style train loop will be more performant because it's more compiler friendly.
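As a rough illustration of that "train_step" unit in plain PyTorch on CPU (no torchax): one function performs the forward pass, the backward pass, and the optimizer update, which is the granularity a compiler such as `torch.compile` or `jax.jit` would see as a whole. The model, shapes, and the name `train_step` are placeholders, not the torchax `train.py` helper linked earlier.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

def train_step(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """One forward + one backward + one optimizer update."""
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizer.step()
    return loss.detach()

# Same input shapes on every call, so a compiled version would not need to retrace.
inputs = torch.randn(64, 16)
labels = torch.randint(0, 4, (64,))
losses = [train_step(inputs, labels).item() for _ in range(50)]
print(losses[0], losses[-1])
```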

qihqi, Apr 17 '25

@qihqi Thanks again for the help here. Do you know of any other workaround that doesn't require us to change our training paradigm to JAX? Currently we can run the least squares in a for loop, but the speed is affected so greatly that it's impractical to train.

If you know of anything else we could try (pulling to the CPU, marking the step early, etc.), please let us know. Or is the only way to make this work to use JAX entirely?

Thanks again.
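One loop-free alternative worth trying, sketched here as an assumption rather than a confirmed fix: for full-rank, reasonably well-conditioned systems like these 251x3 ones, the least-squares solution also satisfies the normal equations A^T A x = A^T b, which can be solved for the whole batch with `torch.linalg.solve` using only matmuls and a small 3x3 solve. This swaps in a different technique than `torch.linalg.lstsq` (it squares the condition number, so it is less robust for ill-conditioned inputs), and whether it sidesteps the XLA lowering problem is untested. The helper name `lstsq_normal_equations` is hypothetical.

```python
import torch

def lstsq_normal_equations(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: batched least squares via the normal equations.

    a: (batch, m, n) with m >= n and full column rank; b: (batch, m, k).
    Returns x of shape (batch, n, k) minimizing ||a @ x - b||.
    """
    at = a.transpose(-2, -1)                   # (batch, n, m)
    return torch.linalg.solve(at @ a, at @ b)  # (batch, n, k)


torch.manual_seed(0)
a = torch.randn(8, 251, 3, dtype=torch.float64)
b = torch.randn(8, 251, 1, dtype=torch.float64)
x_ne = lstsq_normal_equations(a, b)
x_ref = torch.linalg.lstsq(a, b).solution
print(torch.allclose(x_ne, x_ref))  # True for these well-conditioned inputs
```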

ttdd11, Apr 29 '25

@qihqi @yaoshiang we are probably going to move towards torchax, as this bug affects our training too much. Following the examples, we have modified our code to a torchax paradigm. Unfortunately, we are getting faults when we enumerate the dataloader:

https://symbolize.stripped_domain/r/?trace=5e952f6ae95c,7927c544251f&map= *** SIGSEGV STACK OVERFLOW (see go/cppstackoverflow) received by PID 44501 (TID 44501) on cpu 1; stack trace: *** PC: @ 0x5e952f6ae95c (unknown) (unknown) @ 0x792673da9a01 1888 (unknown) @ 0x7927c5442520 1863137424 (unknown) @ 0x5e952faa6540 (unknown) (unknown) https://symbolize.stripped_domain/r/?trace=5e952f6ae95c,792673da9a00,7927c544251f,5e952faa653f&map= E0430 20:55:00.679375 44501 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked. E0430 20:55:00.679387 44501 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start. E0430 20:55:00.679391 44501 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec. E0430 20:55:00.679395 44501 coredump_hook.cc:396] RAW: Sending fingerprint to remote end. E0430 20:55:00.679412 44501 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory E0430 20:55:00.679417 44501 coredump_hook.cc:457] RAW: Dumping core locally. E0430 20:55:00.737472 44501 process_state.cc:806] RAW: Raising signal 11 with default behavior Segmentation fault (core dumped)

I've seen this before when the package versions are mismatched. We are using torch/torch_xla 2.6.0+cpu.cxx11.abi on TPU v6e nodes. My guess is that libtpu is incompatible with torch_xla and jax, which is leading to this fault. Do you have any installation instructions for v6e nodes that we could follow?

Thank you again.

ttdd11, Apr 30 '25

@qihqi @yaoshiang I pulled together a script that uses one of the examples (found here: https://github.com/pytorch/xla/blob/master/torchax/examples/mnist_tpu.ipynb):

To create the node, we used this command:

gcloud alpha compute tpus tpu-vm create test-tpu-vm --zone=us-east5-b --accelerator-type=v6e-1 --version=v2-alpha-tpuv6e

As a starting point - we installed the libraries according to this page https://github.com/pytorch/xla/tree/master/torchax:

pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -U jax[tpu]
pip install torchax
pip install torchvision

which leads to this when importing torchvision:

Traceback (most recent call last):
  File "/home/user/trainer_new2/test_ax.py", line 2, in <module>
    import torchvision
  File "/home/user/.local/lib/python3.10/site-packages/torchvision/__init__.py", line 10, in <module>
    from torchvision import _meta_registrations, datasets, io, models, ops, transforms, utils  # usort:skip
  File "/home/user/.local/lib/python3.10/site-packages/torchvision/_meta_registrations.py", line 164, in <module>
    def meta_nms(dets, scores, iou_threshold):
  File "/home/user/.local/lib/python3.10/site-packages/torch/library.py", line 1023, in register
    use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)
  File "/home/user/.local/lib/python3.10/site-packages/torch/library.py", line 214, in _register_fake
    handle = entry.fake_impl.register(func_to_register, source)
  File "/home/user/.local/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 31, in register
    if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
RuntimeError: operator torchvision::nms does not exist

I've seen this before, and it's related to the CPU/GPU variant of torchvision (it seems to interact with XLA somehow).

To correct this, we installed torch_xla[tpu] and the CPU build of torchvision, which results in this:

File "/home/user/trainer_new2/test_ax.py", line 82, in loss = train_step(sharded_inputs, sharded_labels) File "/home/user/.local/lib/python3.10/site-packages/torchax/distributed.py", line 238, in inner new_states, outputs = _jit_fn(jax_states, args) File "/home/user/.local/lib/python3.10/site-packages/torchax/interop.py", line 172, in call_jax res: JaxValue = jax_func(*args, **kwargs) File "/home/user/.local/lib/python3.10/site-packages/torchax/interop.py", line 179, in call_torch res: TorchValue = torch_func(*args, **kwargs) File "/home/user/.local/lib/python3.10/site-packages/torchax/distributed.py", line 232, in _jit_fn outputs = func(*args) File "/home/user/trainer_new2/test_ax.py", line 66, in train_step optimizer.zero_grad() File "/home/user/.local/lib/python3.10/site-packages/torch/_compile.py", line 51, in inner return disable_fn(*args, **kwargs) File "/home/user/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 838, in _fn return fn(*args, **kwargs) File "/home/user/.local/lib/python3.10/site-packages/torch/optim/optimizer.py", line 962, in zero_grad with torch.autograd.profiler.record_function(self._zero_grad_profile_name): File "/home/user/.local/lib/python3.10/site-packages/torch/autograd/profiler.py", line 771, in enter self.record = torch.ops.profiler._record_function_enter_new( File "/home/user/.local/lib/python3.10/site-packages/torch/_ops.py", line 1158, in call return self._op(*args, **(kwargs or {})) File "/home/user/.local/lib/python3.10/site-packages/torchax/tensor.py", line 249, in torch_function return self.env.dispatch(func, types, args, kwargs) File "/home/user/.local/lib/python3.10/site-packages/torchax/tensor.py", line 488, in dispatch ..... (the error is much longer, just removed the center portion that is repeated). 
File "/home/user/.local/lib/python3.10/site-packages/torchax/tensor.py", line 249, in torch_function return self.env.dispatch(func, types, args, kwargs) File "/home/user/.local/lib/python3.10/site-packages/torchax/tensor.py", line 470, in dispatch args, kwargs = torch_pytree.tree_map_only( File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1333, in tree_map_only return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1143, in tree_map leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1055, in tree_flatten treespec = helper(tree, leaves) File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1051, in helper subspecs = [helper(child, leaves) for child in children] File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1051, in subspecs = [helper(child, leaves) for child in children] File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1051, in helper subspecs = [helper(child, leaves) for child in children] File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1051, in subspecs = [helper(child, leaves) for child in children] File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 1042, in helper if _is_leaf(node, is_leaf=is_leaf): File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 837, in _is_leaf return (is_leaf is not None and is_leaf(tree)) or _get_node_type( File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 830, in _get_node_type if _is_namedtuple_instance(tree): File "/home/user/.local/lib/python3.10/site-packages/torch/utils/_pytree.py", line 821, in _is_namedtuple_instance if len(bases) != 1 or bases[0] != tuple: RecursionError: maximum 
recursion depth exceeded while calling a Python object

For this, we typically raise the recursion limit:

import sys
sys.setrecursionlimit(1000000000)

but that didn't seem to have an effect.
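A possible explanation, offered only as a hypothesis: the repeated `torch_function -> dispatch` frames in the traceback suggest the dispatch is re-entering itself without bound, in which case no finite recursion limit is enough. `sys.setrecursionlimit` only moves Python's own depth check; it does not enlarge the thread's C stack. With a limit of 1e9, a runaway recursion can therefore overflow the C stack and kill the process instead of raising `RecursionError`, which would be consistent with the `SIGSEGV STACK OVERFLOW` in the next run (the script still sets the limit to 1e9). The sketch below only shows the guard doing its job under a normal limit; `runaway` and `probe` are illustrative names.

```python
import sys

def runaway(depth: int) -> int:
    # Unbounded self-recursion, standing in for a dispatch loop that re-enters itself.
    return runaway(depth + 1)

def probe() -> str:
    try:
        runaway(0)
    except RecursionError:
        return "RecursionError"
    return "returned"

print(sys.getrecursionlimit())  # Python's depth guard; it does not resize the C stack
print(probe())                  # RecursionError
```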

So we restarted with the package versions we have used historically:

pip install jax[tpu]==0.4.38
pip install torchax[tpu]
pip install torch==2.6.0+cpu.cxx11.abi https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://download.pytorch.org/whl/torch
pip3 install torchvision==0.21 --index-url https://download.pytorch.org/whl/cpu

python3 test_ax.py WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. /home/user/.local/lib/python3.10/site-packages/torchvision/io/image.py:14: UserWarning: Failed to load image Python extension: '/home/user/.local/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSsb'If you don't plan on using image functionality from torchvision.io, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have libjpeg or libpng installed before building torchvision from source? warn( /home/user/.local/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:82: UserWarning: Transparent hugepages are not enabled. TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer. If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled") warnings.warn( /home/user/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:354: UserWarning: Device capability of jax unspecified, assuming cpu and cuda. Please specify it via the devices argument of register_backend. warnings.warn( [Shard(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=0, data=[[-0.02377432 0.01916441 0.01878371 ... 0.02399771 -0.0104323 0.00296436] [-0.01832963 -0.00767883 0.02187536 ... -0.00146879 -0.0352551 0.02864969] [ 0.01945502 0.03494517 0.00816102 ... -0.00581495 -0.00392314 0.02278726] ... [ 0.01982327 0.00928994 0.03490133 ... -0.00178648 -0.02440884 0.0337343 ] [-0.02462376 -0.01744491 -0.00030861 ... 0.01886704 -0.02126323 0.01498662] [-0.00435257 0.03461109 0.03238511 ... 
0.00440394 -0.01724172 -0.00723013]])] Epoch 0 https://symbolize.stripped_domain/r/?trace=7c225ca2c091,7c225c44251f&map= *** SIGSEGV STACK OVERFLOW (see go/cppstackoverflow) received by PID 8917 (TID 8917) on cpu 22; stack trace: *** PC: @ 0x7c225ca2c091 (unknown) (unknown) @ 0x7c219d1a9a01 1888 (unknown) @ 0x7c225c442520 2048913128 (unknown) @ ... and at least 1 more frames https://symbolize.stripped_domain/r/?trace=7c225ca2c091,7c219d1a9a00,7c225c44251f&map= E0501 12:45:40.956276 8917 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked. E0501 12:45:40.956301 8917 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start. E0501 12:45:40.956306 8917 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec. E0501 12:45:40.956310 8917 coredump_hook.cc:396] RAW: Sending fingerprint to remote end. E0501 12:45:40.956328 8917 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory E0501 12:45:40.956333 8917 coredump_hook.cc:457] RAW: Dumping core locally. E0501 12:45:40.998825 8917 process_state.cc:806] RAW: Raising signal 11 with default behavior Segmentation fault (core dumped)

This is similar to the error I'm getting on the dataloader. Here is the script we are using, which is basically copied from the examples:

import torch
import torch_xla as xla
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import jax
import torchax
import pprint

import sys
sys.setrecursionlimit(1000000000)

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    drop_last=True,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    drop_last=True,
    shuffle=False)

model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)



# The TPU core count will vary depending on your environment.
print(jax.device_count())
ddp_model = torchax.distributed.DistributedDataParallel(model)
example_param = next(ddp_model.parameters())

pprint.pprint(example_param._elem.addressable_shards)

example_images, _ = next(iter(train_loader))
print(example_images.shape)

sharded_example_images = ddp_model.shard_input(example_images)
print(sharded_example_images.shape)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)

@ddp_model.jit_step
def train_step(sharded_inputs, sharded_labels):
  optimizer.zero_grad()
  outputs = ddp_model(sharded_inputs)
  loss = loss_fn(outputs, sharded_labels)
  loss.backward()
  optimizer.step()

  return loss

for epoch in range(10):
  running_loss = 0

  print('Epoch', epoch)
  for i, data in enumerate(train_loader):
      inputs, labels = data
      # Distribute the batch across all TPU cores
      sharded_inputs, sharded_labels = ddp_model.shard_input(inputs), ddp_model.shard_input(labels)
      loss = train_step(sharded_inputs, sharded_labels)

      if i % 100 == 0:
          print('  batch {} loss: {}'.format(i, loss.item()))
          running_loss = 0.

Sorry that this is long-winded; we're really looking for some insight here to get this going.

Thank you again.

ttdd11, May 01 '25

BTW, are you in touch with Ela Jamali?

yaoshiang, May 01 '25

@yaoshiang yes

ttdd11, May 01 '25

I can take a look at this.

I was able to repro the crash on the May 2 nightly of torch_xla, and have confirmed that the equivalent code doesn't crash on CPU.

tengyifei, May 05 '25

Hi @ttdd11,

I will further debug what is going on with the environment.

Meanwhile, we landed a new feature, assume_pure, in nightly. I tried it, and it solves the lstsq problem in torch_xla (which might be easier, since moving to torchax is a bigger change).

First replace torch and torch_xla with nightly:

pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
# Edit `cp310-cp310` to fit your desired Python version as needed
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

Then, run this code:



import torch

import torch_xla as xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.assume_pure import assume_pure

device = xm.xla_device()

@assume_pure
def lstsq(x, y):
    return torch.linalg.lstsq(x, y).solution
    

with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device) 
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = A*B
    F = C*D
    ref = lstsq(E, F)
    loss = torch.mean(ref - diff)
    loss.backward()
    print(A.grad)

This gave:

tensor([[[-1.1189e-05, -8.5187e-06, -9.6764e-06],
         [-1.1518e-06, -1.5361e-06, -1.7471e-06],
         [-5.4305e-06, -3.1749e-05, -3.6162e-05],
         ...,
         [-1.2812e-04, -6.6418e-05, -7.5332e-05],
         [-3.5903e-06,  1.8855e-06,  2.1583e-06],
         [-1.0248e-06,  9.3509e-08,  1.0934e-07]],

        [[ 2.0873e-05,  1.6173e-05,  1.2404e-05],
         [ 1.0282e-04,  7.8961e-05,  9.3233e-05],
         [-2.4210e-05, -1.8589e-05, -2.2111e-05],
         ...,
         [ 2.5044e-05,  1.9359e-05,  1.7007e-05],
         [-1.6464e-05, -1.2604e-05, -1.6739e-05],
         [ 1.2520e-06,  9.6230e-07,  1.0983e-06]],

        [[-9.2494e-05, -8.7698e-05, -1.1561e-04],
         [-3.2149e-05, -3.0673e-05, -4.0150e-05],
         [ 8.2621e-05,  7.7064e-05,  1.0350e-04],
         ...,
         [ 6.0826e-06,  5.9933e-06,  7.5614e-06],
         [ 1.4669e-05,  1.3372e-05,  1.8434e-05],
         [-2.2157e-04, -2.3709e-04, -2.7201e-04]],

        ...,

        [[-1.0429e-05, -1.9833e-05, -7.0310e-07],
         [ 1.1624e-04,  4.2236e-05,  1.4492e-04],
         [-1.5279e-04, -1.4217e-04, -1.2406e-04],
         ...,
         [-7.5765e-05, -7.5498e-05, -5.7689e-05],
         [-5.0151e-05, -4.9623e-05, -3.8456e-05],
         [ 4.0996e-05,  4.1791e-05,  3.0495e-05]],

        [[-4.6437e-05, -4.3407e-05, -5.4165e-05],
         [ 2.2402e-06,  9.0005e-06, -2.7746e-06],
         [-1.2474e-05, -1.4912e-05, -1.2014e-05],
         ...,
         [-7.2032e-05, -6.6849e-05, -8.4396e-05],
         [ 7.0603e-05,  2.4920e-04, -6.0555e-05],
         [ 1.1091e-04,  2.4513e-04,  1.9026e-05]],

        [[ 5.7307e-05,  4.5501e-05,  5.6519e-05],
         [-1.1200e-04, -1.0720e-04, -1.3477e-04],
         [-4.0535e-05, -3.6372e-05, -4.5547e-05],
         ...,
         [-5.3763e-06, -5.6098e-06, -7.0862e-06],
         [ 1.9833e-05,  1.5748e-05,  1.9561e-05],
         [ 2.1044e-05,  1.5666e-05,  1.9367e-05]]], device='xla:0')

Please try this instead and let me know whether it works for you, and sorry for going back and forth with different suggestions.

qihqi, May 08 '25

Pass to @qihqi

tengyifei, May 08 '25

@qihqi @tengyifei thanks for the help here.

We were able to give this a try and at first it looked promising.

For starters, the dataloader setup needed some modifications. We have had problems with it in the past, and we have been using this:

self.loader = torch.utils.data.DataLoader(self.dataset, batch_size=None, shuffle=False, drop_last=False, num_workers=1, persistent_workers=True, pin_memory=True, collate_fn=self.collate_fn)

However, with this install we are getting device errors with any setting other than num_workers = 0, which rules out pin_memory and persistent workers. This may be the cause of the next issue, but it's hard to know.
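For reference, a CPU-only sketch of how these DataLoader options interact, with a hypothetical helper `make_loader` and a toy dataset standing in for `self.dataset`. `persistent_workers=True` requires `num_workers > 0` (DataLoader raises a `ValueError` otherwise), so it has to be dropped when falling back to `num_workers=0`. `pin_memory` is accepted with any worker count, since with 0 workers the pinning happens in the main process. Whether the device errors above come from the worker processes themselves is not established here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, num_workers: int, batch_size=None, collate_fn=None):
    """Hypothetical helper mirroring the loader settings above."""
    kwargs = dict(batch_size=batch_size, shuffle=False, drop_last=False,
                  num_workers=num_workers, collate_fn=collate_fn)
    if num_workers > 0:
        # Only valid with worker processes; DataLoader raises ValueError otherwise.
        kwargs["persistent_workers"] = True
    return DataLoader(dataset, **kwargs)

dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
loader = make_loader(dataset, num_workers=0, batch_size=4)
print([batch[0].tolist() for batch in loader])

try:
    DataLoader(dataset, num_workers=0, persistent_workers=True)
except ValueError as err:
    print("ValueError:", err)
```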

With num_workers set to 0, we ran the same training job (which has run for 100+ epochs with torch_xla 2.6), and after a number of steps that isn't reproducible (~1000, so ~6000000 images) we get a memory error:

if not torch.isnan(loss).any() and torch.is_nonzero(loss):

RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 126.56M. That was not possible. There are 46.46M free.; (2x2x0_HBM0)

If we remove that line, the loss goes to zero instead of hitting the memory error.

I'm not entirely sure where to go from here to troubleshoot this. The only two changes are the dataloader and the code shown above.

Thanks again.

Edit:

Ran on 2.6 with the same dataloader and it works as expected. Tomorrow we will run without assume_pure to see whether that is the issue or whether it's something else in nightly torch_xla.

ttdd11, May 18 '25

@ttdd11, I'm wondering: does this issue show up in single-device, multi-processing (torch_xla.spawn()), or SPMD (https://pytorch.org/xla/master/spmd.html) execution mode?

tengyifei, May 20 '25

I'll try to reproduce these issues by adapting the script in https://github.com/pytorch/xla/issues/8953#issuecomment-2844826841.

In the meantime, @ttdd11 if you have a reproducer script for one or more of the issues encountered, that would be helpful.

tengyifei, May 20 '25

@tengyifei apologies for my delay here - we will put together an example where the dataloader issue is occurring and hopefully one where the memory becomes an issue.

ttdd11, May 23 '25

@tengyifei Apologies for the delay. Unfortunately we don't have a minimal example that replicates all the issues, but here is a starting point for some of them. I believe the issues we are seeing are related to the dataloader. Here is a short example of one issue with the nightly dataloader:

import sys
import torch
import torch_xla

#defines for sizing
batch = 48
channel = 3
width = 480
height = 480

# Define a custom dataset
from torch.utils.data import IterableDataset, DataLoader
class InfiniteDataset(IterableDataset):
    def __init__(self):
        super().__init__()

    def __iter__(self):
        while True:
            yield [torch.randn(channel,width,height),torch.randn(251,3),torch.randn(3,1)]

#model class
from torchvision.models.efficientnet import efficientnet_v2_l
from torch import nn
class Backbone(nn.Module):
    """
    use SpatialAtt + ChannelAtt
    """
    def __init__(self):
        super().__init__()
        self.backbone = efficientnet_v2_l()
        self.conv = nn.Conv1d(1000,251*3,  kernel_size=1, stride=1,padding=0,bias=True)
    def forward(self, x):
        b = x.shape[0]
        y = self.backbone(x)
        y = y.unsqueeze(dim = 2)
        y = self.conv(y)
        y = y.view([b,251,3])
        return y

#train update 
import numpy as np
def _train_update(step, loss):
    if torch.is_tensor(loss):
       loss = np.mean(loss.detach().cpu().numpy())
    update_data = ['Training','Step={}'.format(step),'Loss={:.5f}'.format(loss)]
    print('|', ' '.join(item for item in update_data if item), flush=True)


    
def train_routine():
    #Note - these need to be included in this routine or this error appears
    #    raise RuntimeError('Runtime is already initialized. Do not use the XLA 
    #    RuntimeError: Runtime is already initialized. Do not use the XLA device before calling xmp.spawn.
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast
    from torch_xla.amp import syncfree
    import torch_xla.distributed.parallel_loader as pl
    
    #pure least squares
    from torch_xla.experimental.assume_pure import assume_pure
    @assume_pure
    def lstsq(x, y):
        return torch.linalg.lstsq(x, y).solution
        
    device = xm.xla_device()
    print("device")
    model = Backbone()
    model = model.to(device=device)
    model.train()
    optimizer = syncfree.Adam(
    model.parameters(),
    lr=0.000001,
    weight_decay=0.0000001)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 10000000, eta_min=0.000001, last_epoch=-1)    
    
    def train_loop_fn(loader):
        step = 0
        for s, (feed,label,diff) in enumerate(loader):
            step += 1
            with autocast(xm.xla_device()):
                out = model(feed)
                ref = lstsq(out, label)
                loss = torch.mean(ref - diff)
            if not torch.isnan(loss).any() and torch.is_nonzero(loss):
                loss.backward()
                xm.optimizer_step(optimizer)
                lr_scheduler.step()
                if step % 20 == 0:
                    xm.add_step_closure(_train_update,args=(step, loss))
            else:
                print('nan or zero',flush = True)
   
    infinite_dataset = InfiniteDataset()
    #this line causes the issue
    #dataloader = DataLoader(infinite_dataset, batch_size=batch,shuffle=False, drop_last=False, num_workers=2,persistent_workers=True, pin_memory = True)
    # removing persistent_workers and pin_memory allows it to continue
    dataloader = DataLoader(infinite_dataset, batch_size=batch,shuffle=False, drop_last=False, num_workers=2)

    train_loader = pl.MpDeviceLoader(dataloader, device)
    while True:
        train_loop_fn(train_loader)

def _mp_fn(index):
    train_routine()
    sys.exit(21)

def main(args):
    torch_xla.launch(_mp_fn, args=())

if __name__ == '__main__':

    import sys
    main(sys.argv[1:])

In 2.7 (see the commented-out line above), we are unable to pin memory or use persistent workers.

Historically we have used 1 worker with pinned memory and persistent workers, as that allowed training runs to continue uninterrupted. Even in torch 2.6, if we use more than 1 worker, training stops at some non-deterministic point, and with 1 worker the dataloader is the bottleneck.

I'm not entirely sure how to see what is going on here - any debugging advice would be greatly appreciated.

EDIT: The dataloader stall occurs when num_workers >= 1; with 0 it does not occur.

ttdd11, Jun 03 '25

@junjieqian

tengyifei, Jun 06 '25

@tengyifei I have a minimal example that brings in a bit more of our process; however, we are finding that assume_pure is quite a bit slower than when we first started experimenting. What would be the best way for us to test this? I would like to verify that the slowdown comes from a package change and not from our process as we finalize this.

ttdd11, Jun 11 '25

Hi @ttdd11 , thanks for the update! I take this task now and am happy to support.

> quite a bit slower

Would you mind sharing more? I assume you are using different package versions; if so, which versions? It would also be really appreciated if you could share the minimal example along with the setup needed to run it. I will try to repro locally and investigate further.

junjieqian, Jun 11 '25