torch.linalg.lstsq issues on GPU/TPU
When using CPU as the device, torch.linalg.lstsq works as expected. When using GPU/TPU, it fails with a shape-mismatch error.
To reproduce:
import torch as torch
import torch_xla as xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device)
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = A*B
    F = C*D
    ref = torch.linalg.lstsq(A*B,C*D).solution
    loss = torch.mean(ref - diff)
    loss.backward()
With the error:
E0408 16:13:38.043340 1246240 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: INVALID_ARGUMENT: Executable expected shape f32[3]{0} for argument 0 but got incompatible shape f32[8,3]{1,0}
Tested using torch_xla 2.4 on GPU and 2.6 on TPU. If you don't pass the tensors to the device, this runs as expected.
I'm assuming that pulling this down to the CPU and then pushing the results back to the device would cause slowdowns that would make TPU usage impractical, so any help here would be greatly appreciated.
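For reference, a rough sketch of the CPU round trip I have in mind (reusing A, B, C, D, diff, and device from the snippet above); the host transfer would force a sync on every step, which is exactly the slowdown I'm worried about:
E_cpu = (A * B).cpu()
F_cpu = (C * D).cpu()
ref = torch.linalg.lstsq(E_cpu, F_cpu).solution.to(device)
loss = torch.mean(ref - diff)
loss.backward()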
Thanks for sharing this issue. Can you share the line that triggers the exception? I assume it's the call to lstsq.
Can you also run it with E and F directly as
ref = torch.linalg.lstsq(E, F).solution
Can you materialize tensors with the same shapes as E and F directly like this and try again:
ref = torch.linalg.lstsq(torch.randn_like(E), torch.randn_like(F)).solution
And finally, to break compilation, can you do this:
G = torch.randn_like(E)
H = torch.randn_like(F)
xm.mark_step()
ref = torch.linalg.lstsq(G, H).solution
thanks so much
@yaoshiang thanks for the reply. Here are the results:
- Same as above
- No crash
- No crash
- I also tried this (but that also couldn't run the backward without crashing):
import torch_xla as xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device)
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = torch.mul(A,B)
    F = torch.mul(C,D)
    G = E.clone()
    H = F.clone()
    xm.mark_step()
    ref = torch.linalg.lstsq(G, H).solution
    #ref = torch.linalg.lstsq(torch.randn_like(E), torch.randn_like(F)).solution
    # ref = torch.cat([ref[:, :2], ref[:, 2:]], dim=1)
    #ref = torch.cat([ref[:, :2]*scale_rel_backproj, ref[:, 2:] * (scale_rel_backproj / scale2d)], dim=1)
    #ref = torch.squeeze(ref, dim=-1)
    loss = torch.mean(ref - diff)
    loss.backward()
So it looks like it can handle the batch dim provided it's not multiplied beforehand.
Any advice on how to manage this with the multiplications?
Thanks again.
Thanks so much. Looks like an issue with the longer series of ops during compilation. Can you try this as a potential workaround:
with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device)
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = A*B
    F = C*D
    xm.mark_step()  # no clones
    ref = torch.linalg.lstsq(A*B,C*D).solution
    loss = torch.mean(ref - diff)
    loss.backward()
@yaoshiang thanks for the reply. Here is the code that I tried:
import torch as torch
import torch_xla as xla
import torch_xla.core.xla_model as xm
device = xm.xla_device()
with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device)
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    xm.mark_step()
    ref = torch.linalg.lstsq(A*B, C*D).solution
    loss = torch.mean(ref - diff)
    loss.backward()
which unfortunately yields the same result - the crash occurs on the backward pass with the same mismatch error.
Hi,
You can try this in torchax: (https://github.com/pytorch/xla/tree/master/torchax)
import torch as torch
import torchax
torchax.enable_globally()
device = 'jax'
diff = torch.randn(8,3,1,requires_grad=True,device=device)
A = torch.randn(8,251,3,requires_grad=True,device=device)
B = torch.randn(8,251,1,requires_grad=True,device=device)
C = torch.randn(8,251,1,requires_grad=True,device=device)
D = torch.randn(8,251,1,requires_grad=True,device=device)
#TODO(https://github.com/pytorch/xla/issues/8983)
A.requires_grad = True
B.requires_grad = True
C.requires_grad = True
D.requires_grad = True
ref = torch.linalg.lstsq(A*B, C*D).solution
loss = torch.mean(ref - diff)
loss.backward()
print(A.grad)
An even better approach (which doesn't need requires_grad set explicitly):
def func(A, B, C, D):
    ref = torch.linalg.lstsq(A*B, C*D).solution
    loss = torch.mean(ref - diff)
    return loss
grad_fn = torchax.interop.jax_value_and_grad(func)
loss, gradients = grad_fn(A, B, C, D)
print(gradients[0])
@qihqi thank you again! I can give this a try today. Other than the install instructions you provided in that link, are there any other differences to using jax as the device?
I wasn't able to get the same error on GPU.
Traceback (most recent call last):
File "8953.py", line 16, in <module>
ref = torch.linalg.lstsq(A*B,C*D).solution
RuntimeError: Trying to resize storage that is not resizable
PyTorch/XLA: 9ec626e653336f47221636839c9d31110e5f0a80
I ran it on pt-cuda and it worked successfully.
@yaoshiang what version of torch are you using? I tested this on 2.4 GPU and 2.6 TPU with the same error. Also - what variant of the code are you running?
Thanks for the help.
@qihqi thank you again! I can give this a try today. Other than the install instructions you provided in that link, are there any other differences to using jax as the device?
Yes, so the speed you get is jax-eager speed. If you want to get faster speed with compilation, you need to wrap the function with torchax.interop.jax_jit which is equivalent to jax.jit.
More details here: https://docs.jax.dev/en/latest/faq.html#is-jax-faster-than-numpy
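For example, a minimal sketch for the lstsq case above (assuming the same func, A, B, C, D from the earlier snippet, and that the input shapes stay fixed between calls):
jitted_func = torchax.interop.jax_jit(func)  # traced and compiled on the first call, reused afterwards
loss = jitted_func(A, B, C, D)
This is only a sketch; the speedup you get depends on how much of the computation stays inside the jitted function.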
If you want to train a model, the recommendation would be to use this function https://github.com/pytorch/xla/blob/master/torchax/torchax/train.py#L15 to generate a train step and call it repeatedly, instead of the regular loss.backward(); optimizer.step() incantation.
@qihqi I'm currently running this eagerly anyway - I think the nature of a least-squares optimization requires a ton of materialization of intermediate values, so I'm not sure getting away from eager is possible. This function works as is in a for loop over the batch, and without running eager it's 10x slower.
Regarding the training of the model: if I use the code above, are you recommending that we adjust our entire training paradigm to use jax, or is that only if we want graph-compiled execution instead of eager?
@yaoshiang what version of torch are you using? I tested this on 2.4 GPU and 2.6 TPU with the same error. Also - what variant of the code are you running?
Thanks for the help.
Sorry for the confusion - I ran it on an Nvidia T4 with the latest torch, no torch_xla, and it ran. If you got different results, let me know.
Regarding the training of the model: if I use the code above, are you recommending that we adjust our entire training paradigm to use jax, or is that only if we want graph-compiled execution instead of eager?
Only if you use graph compilation.
To be clear, when using torchax, both eager mode and compiled mode will use Jax to execute under-the-hood.
When using graph compilation, you want to compile the thing that you will launch repeatedly without changing input shapes, the "train_step" concept fits this well -- a train step is one forward + one backward + one optimizer update (NOTE, if you use torch.compile on GPU, this is also the recommendation: https://pytorch.org/tutorials/intermediate/torch_compile_tutorial_.html).
With torchax, you can do both torch style train loop OR jax style train loop. The jax style train loop will be more performant because it's more compiler friendly.
@qihqi Thanks again for the help here - do you know of any other workaround that doesn't require us to change the training paradigm to use jax? Currently we can run the least squares in a for loop, but the speed is affected so greatly that it's impractical to train.
If you know of anything else we could try - like pulling to CPU, marking the step early, etc. - that would be great. Or is using jax entirely the only way to make this work?
Thanks again.
@qihqi @yaoshiang we are probably going to move towards torchax, as this bug affects our training too much. Following the examples, we have modified our code to the torchax paradigm. Unfortunately, we are getting faults when we enumerate the dataloader:
https://symbolize.stripped_domain/r/?trace=5e952f6ae95c,7927c544251f&map=
*** SIGSEGV STACK OVERFLOW (see go/cppstackoverflow) received by PID 44501 (TID 44501) on cpu 1; stack trace: ***
PC: @ 0x5e952f6ae95c (unknown) (unknown)
@ 0x792673da9a01 1888 (unknown)
@ 0x7927c5442520 1863137424 (unknown)
@ 0x5e952faa6540 (unknown) (unknown)
https://symbolize.stripped_domain/r/?trace=5e952f6ae95c,792673da9a00,7927c544251f,5e952faa653f&map=
E0430 20:55:00.679375 44501 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0430 20:55:00.679387 44501 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E0430 20:55:00.679391 44501 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0430 20:55:00.679395 44501 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0430 20:55:00.679412 44501 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0430 20:55:00.679417 44501 coredump_hook.cc:457] RAW: Dumping core locally.
E0430 20:55:00.737472 44501 process_state.cc:806] RAW: Raising signal 11 with default behavior
Segmentation fault (core dumped)
I've seen this before when the versions are mismatched. We are using torch/torch_xla 2.6.0+cpu.cxx11.abi on TPU v6e nodes. My guess is that libtpu is incompatible with torch_xla and jax, which is leading to this fault. Do you have any installation instructions for v6e nodes that we could follow that would help with this?
Thank you again.
@qihqi @yaoshiang pulled together a script that uses one of the examples (found here: https://github.com/pytorch/xla/blob/master/torchax/examples/mnist_tpu.ipynb):
To create the node, we used this command:
gcloud alpha compute tpus tpu-vm create test-tpu-vm --zone=us-east5-b --accelerator-type=v6e-1 --version=v2-alpha-tpuv6e
As a starting point - we installed the libraries according to this page https://github.com/pytorch/xla/tree/master/torchax:
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -U jax[tpu]
pip install torchax
pip install torchvision
which leads to this when importing torchvision:
Traceback (most recent call last):
File "/home/user/trainer_new2/test_ax.py", line 2, in
I've seen this before - it's related to the CPU/GPU version of torchvision (and it seems to interact with xla somehow).
To correct this, we install torch_xla[tpu] and torchvision cpu, which results in this:
File "/home/user/trainer_new2/test_ax.py", line 82, in
For this, typically we set the recursion limit to be higher using:
import sys
sys.setrecursionlimit(1000000000)
but that didn't seem to have an effect.
So restarting with new installs that we have used historically:
pip install jax[tpu]==0.4.38
pip install torchax[tpu]
pip install torch==2.6.0+cpu.cxx11.abi https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html -f https://download.pytorch.org/whl/torch
pip3 install torchvision==0.21 --index-url https://download.pytorch.org/whl/cpu
python3 test_ax.py
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
/home/user/.local/lib/python3.10/site-packages/torchvision/io/image.py:14: UserWarning: Failed to load image Python extension: '/home/user/.local/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSsb'If you don't plan on using image functionality from torchvision.io, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have libjpeg or libpng installed before building torchvision from source?
warn(
/home/user/.local/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:82: UserWarning: Transparent hugepages are not enabled. TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer. If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled")
warnings.warn(
/home/user/.local/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:354: UserWarning: Device capability of jax unspecified, assuming cpu and cuda. Please specify it via the devices argument of register_backend.
warnings.warn(
[Shard(device=TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), index=(slice(None, None, None), slice(None, None, None)), replica_id=0, data=[[-0.02377432 0.01916441 0.01878371 ... 0.02399771 -0.0104323
0.00296436]
[-0.01832963 -0.00767883 0.02187536 ... -0.00146879 -0.0352551
0.02864969]
[ 0.01945502 0.03494517 0.00816102 ... -0.00581495 -0.00392314
0.02278726]
...
[ 0.01982327 0.00928994 0.03490133 ... -0.00178648 -0.02440884
0.0337343 ]
[-0.02462376 -0.01744491 -0.00030861 ... 0.01886704 -0.02126323
0.01498662]
[-0.00435257 0.03461109 0.03238511 ... 0.00440394 -0.01724172
-0.00723013]])]
Epoch 0
https://symbolize.stripped_domain/r/?trace=7c225ca2c091,7c225c44251f&map=
*** SIGSEGV STACK OVERFLOW (see go/cppstackoverflow) received by PID 8917 (TID 8917) on cpu 22; stack trace: ***
PC: @ 0x7c225ca2c091 (unknown) (unknown)
@ 0x7c219d1a9a01 1888 (unknown)
@ 0x7c225c442520 2048913128 (unknown)
@ ... and at least 1 more frames
https://symbolize.stripped_domain/r/?trace=7c225ca2c091,7c219d1a9a00,7c225c44251f&map=
E0501 12:45:40.956276 8917 coredump_hook.cc:301] RAW: Remote crash data gathering hook invoked.
E0501 12:45:40.956301 8917 coredump_hook.cc:340] RAW: Skipping coredump since rlimit was 0 at process start.
E0501 12:45:40.956306 8917 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0501 12:45:40.956310 8917 coredump_hook.cc:396] RAW: Sending fingerprint to remote end.
E0501 12:45:40.956328 8917 coredump_hook.cc:405] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0501 12:45:40.956333 8917 coredump_hook.cc:457] RAW: Dumping core locally.
E0501 12:45:40.998825 8917 process_state.cc:806] RAW: Raising signal 11 with default behavior
Segmentation fault (core dumped)
and this is similar to the error that I'm getting on the dataloader - here is the script we are using, which is basically copied from the examples:
import torch
import torch_xla as xla
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import jax
import torchax
import pprint
import sys
sys.setrecursionlimit(1000000000)
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))]))
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    drop_last=True,
    shuffle=True)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    drop_last=True,
    shuffle=False)
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10)
)
# The TPU core count will vary depending on your environment.
jax.device_count()
ddp_model = torchax.distributed.DistributedDataParallel(model)
example_param = next(ddp_model.parameters())
pprint.pprint(example_param._elem.addressable_shards)
example_images, _ = next(iter(train_loader))
example_images.shape
sharded_example_images = ddp_model.shard_input(example_images)
sharded_example_images.shape
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.001, momentum=0.9)
@ddp_model.jit_step
def train_step(sharded_inputs, sharded_labels):
    optimizer.zero_grad()
    outputs = ddp_model(sharded_inputs)
    loss = loss_fn(outputs, sharded_labels)
    loss.backward()
    optimizer.step()
    return loss
for epoch in range(10):
    running_loss = 0
    print('Epoch', epoch)
    for i, data in enumerate(train_loader):
        inputs, labels = data
        # Distribute the batch across all TPU cores
        sharded_inputs, sharded_labels = ddp_model.shard_input(inputs), ddp_model.shard_input(labels)
        loss = train_step(sharded_inputs, sharded_labels)
        if i % 100 == 0:
            print(' batch {} loss: {}'.format(i, loss.item()))
            running_loss = 0.
Sorry that this is long-winded - really looking for some insight here to get this going.
Thank you again.
BTW are you in touch with Ela Jamali?
@yaoshiang yes
I can take a look at this.
I was able to repro the crash on May 2 nightly of torch_xla, and have confirmed that the equivalent code doesn't crash on CPU.
Hi @ttdd11,
I will further debug what is going on with the environment.
Meanwhile, we landed a new feature, assume_pure, in nightly, and I tried it: it can solve the lstsq issue in torch_xla (which might be easier, as moving to torchax is a bigger change).
First replace torch and torch_xla with nightly:
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
# Edit `cp310-cp310` to fit your desired Python version as needed
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
-f https://storage.googleapis.com/libtpu-wheels/index.html
Then, this bit of code:
import torch as torch
import torch_xla as xla
import torch_xla.core.xla_model as xm
from torch_xla.experimental.assume_pure import assume_pure
device = xm.xla_device()
@assume_pure
def lstsq(x, y):
    return torch.linalg.lstsq(x, y).solution
with xla.step():
    diff = torch.randn(8,3,1,requires_grad=True,device=device)
    A = torch.randn(8,251,3,requires_grad=True,device=device)
    B = torch.randn(8,251,1,requires_grad=True,device=device)
    C = torch.randn(8,251,1,requires_grad=True,device=device)
    D = torch.randn(8,251,1,requires_grad=True,device=device)
    E = A*B
    F = C*D
    ref = lstsq(E, F)
    loss = torch.mean(ref - diff)
    loss.backward()
print(A.grad)
Gave
tensor([[[-1.1189e-05, -8.5187e-06, -9.6764e-06],
[-1.1518e-06, -1.5361e-06, -1.7471e-06],
[-5.4305e-06, -3.1749e-05, -3.6162e-05],
...,
[-1.2812e-04, -6.6418e-05, -7.5332e-05],
[-3.5903e-06, 1.8855e-06, 2.1583e-06],
[-1.0248e-06, 9.3509e-08, 1.0934e-07]],
[[ 2.0873e-05, 1.6173e-05, 1.2404e-05],
[ 1.0282e-04, 7.8961e-05, 9.3233e-05],
[-2.4210e-05, -1.8589e-05, -2.2111e-05],
...,
[ 2.5044e-05, 1.9359e-05, 1.7007e-05],
[-1.6464e-05, -1.2604e-05, -1.6739e-05],
[ 1.2520e-06, 9.6230e-07, 1.0983e-06]],
[[-9.2494e-05, -8.7698e-05, -1.1561e-04],
[-3.2149e-05, -3.0673e-05, -4.0150e-05],
[ 8.2621e-05, 7.7064e-05, 1.0350e-04],
...,
[ 6.0826e-06, 5.9933e-06, 7.5614e-06],
[ 1.4669e-05, 1.3372e-05, 1.8434e-05],
[-2.2157e-04, -2.3709e-04, -2.7201e-04]],
...,
[[-1.0429e-05, -1.9833e-05, -7.0310e-07],
[ 1.1624e-04, 4.2236e-05, 1.4492e-04],
[-1.5279e-04, -1.4217e-04, -1.2406e-04],
...,
[-7.5765e-05, -7.5498e-05, -5.7689e-05],
[-5.0151e-05, -4.9623e-05, -3.8456e-05],
[ 4.0996e-05, 4.1791e-05, 3.0495e-05]],
[[-4.6437e-05, -4.3407e-05, -5.4165e-05],
[ 2.2402e-06, 9.0005e-06, -2.7746e-06],
[-1.2474e-05, -1.4912e-05, -1.2014e-05],
...,
[-7.2032e-05, -6.6849e-05, -8.4396e-05],
[ 7.0603e-05, 2.4920e-04, -6.0555e-05],
[ 1.1091e-04, 2.4513e-04, 1.9026e-05]],
[[ 5.7307e-05, 4.5501e-05, 5.6519e-05],
[-1.1200e-04, -1.0720e-04, -1.3477e-04],
[-4.0535e-05, -3.6372e-05, -4.5547e-05],
...,
[-5.3763e-06, -5.6098e-06, -7.0862e-06],
[ 1.9833e-05, 1.5748e-05, 1.9561e-05],
[ 2.1044e-05, 1.5666e-05, 1.9367e-05]]], device='xla:0')
Please try this instead and let me know if that works for you or not, and sorry for going back and forth with different suggestions.
Pass to @qihqi
@qihqi @tengyifei thanks for the help here.
We were able to give this a try and at first it looked promising.
For starters - the dataloader setup needed some modifications. We have had problems with them in the past and we have used this:
self.loader = torch.utils.data.DataLoader(self.dataset, batch_size=None, shuffle=False, drop_last=False, num_workers=1, persistent_workers=True, pin_memory=True, collate_fn=self.collate_fn)
However, with this install we are getting device errors with anything other than num_workers=0, which eliminates pin_memory and persistent_workers. This may be the cause of the next issue, but it's hard to know.
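Concretely, the fallback we end up with under the nightly install is just this (same arguments as the loader above, minus the worker, pinning, and persistence options):
self.loader = torch.utils.data.DataLoader(self.dataset, batch_size=None, shuffle=False, drop_last=False, num_workers=0, collate_fn=self.collate_fn)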
With the number of workers set to 0, we run the same training (which has run for 100+ epochs with torch_xla 2.6), and after a number of steps that isn't reproducible (~1000, so ~6,000,000 images) we get a memory error:
if not torch.isnan(loss).any() and torch.is_nonzero(loss):
RuntimeError: Bad StatusOr access: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 126.56M. That was not possible. There are 46.46M free.; (2x2x0_HBM0)
If we remove that line - the loss goes to zero instead of hitting the memory error.
I'm not entirely sure where to go from here to troubleshoot this. The only two changes are the data loader and the code shown above.
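One thing we may try next (assuming the standard torch_xla debug metrics API is available in the nightly build) is printing the metrics report every so often, to see whether compilations or device transfers keep growing before the RESOURCE_EXHAUSTED error:
import torch_xla.debug.metrics as met
# e.g. every 100 steps inside the training loop
if step % 100 == 0:
    print(met.metrics_report())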
Thanks again.
Edit:
Ran on 2.6 with the same data loader and it works as expected. Will run tomorrow without assume_pure to see if that's the issue or if it's something else with the nightly torch_xla.
@ttdd11 wondering does this issue show up in single device, multi-processing (torch_xla.spawn()), or SPMD (https://pytorch.org/xla/master/spmd.html) execution mode?
I'll try to reproduce these issues by adapting the script in https://github.com/pytorch/xla/issues/8953#issuecomment-2844826841.
In the meantime, @ttdd11 if you have a reproducer script for one or more of the issues encountered, that would be helpful.
@tengyifei apologies for my delay here - we will put together an example where the dataloader issue is occurring and hopefully one where the memory becomes an issue.
@tengyifei Apologies for the delay. Unfortunately we don't have a minimal example that replicates all issues, but here is a starting point for some of them. I believe that the issues we are seeing are related to the dataloader. Here is a short example of one issue with the nightly dataloader:
import torch as torch
import torch_xla
#defines for sizing
batch = 48
channel = 3
width = 480
height = 480
# Define a custom dataset
from torch.utils.data import IterableDataset, DataLoader
class InfiniteDataset(IterableDataset):
    def __init__(self):
        gg = 0
    def __iter__(self):
        while True:
            yield [torch.randn(channel,width,height),torch.randn(251,3),torch.randn(3,1)]
#model class
from torchvision.models.efficientnet import efficientnet_v2_l
from torch import nn
class Backbone(nn.Module):
    """
    use SpatialAtt + ChannelAtt
    """
    def __init__(self):
        super().__init__()
        self.backbone = efficientnet_v2_l()
        self.conv = nn.Conv1d(1000,251*3, kernel_size=1, stride=1,padding=0,bias=True)
    def forward(self, x):
        b = x.shape[0]
        y = self.backbone(x)
        y = y.unsqueeze(dim = 2)
        y = self.conv(y)
        y = y.view([b,251,3])
        return y
#train update
import numpy as np
def _train_update(step, loss):
    if torch.is_tensor(loss):
        loss = np.mean(loss.detach().cpu().numpy())
    update_data = ['Training','Step={}'.format(step),'Loss={:.5f}'.format(loss)]
    print('|', ' '.join(item for item in update_data if item), flush=True)
def train_routine():
    #Note - these need to be included in this routine or this error appears
    # raise RuntimeError('Runtime is already initialized. Do not use the XLA
    # RuntimeError: Runtime is already initialized. Do not use the XLA device before calling xmp.spawn.
    import torch_xla.core.xla_model as xm
    from torch_xla.amp import autocast
    from torch_xla.amp import syncfree
    import torch_xla.distributed.parallel_loader as pl
    #pure least squares
    from torch_xla.experimental.assume_pure import assume_pure
    @assume_pure
    def lstsq(x, y):
        return torch.linalg.lstsq(x, y).solution
    device = xm.xla_device()
    print("device")
    model = Backbone()
    model = model.to(device=device)
    model.train()
    optimizer = syncfree.Adam(
        model.parameters(),
        lr=0.000001,
        weight_decay=0.0000001)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 10000000, eta_min=0.000001, last_epoch=-1)
    def train_loop_fn(loader):
        step = 0
        for s, (feed,label,diff) in enumerate(loader):
            step += 1
            with autocast(xm.xla_device()):
                out = model(feed)
                ref = lstsq(out, label)
                loss = torch.mean(ref - diff)
            if not torch.isnan(loss).any() and torch.is_nonzero(loss):
                loss.backward()
                xm.optimizer_step(optimizer)
                lr_scheduler.step()
                if step % 20 == 0:
                    xm.add_step_closure(_train_update,args=(step, loss))
            else:
                print('nan or zero',flush = True)
    infinite_dataset = InfiniteDataset()
    #this line causes the issue
    #dataloader = DataLoader(infinite_dataset, batch_size=batch,shuffle=False, drop_last=False, num_workers=2,persistent_workers=True, pin_memory = True)
    #removing persistent workers and pinning memory allows it to continue
    dataloader = DataLoader(infinite_dataset, batch_size=batch,shuffle=False, drop_last=False, num_workers=2)
    train_loader = pl.MpDeviceLoader(dataloader, device)
    while True:
        train_loop_fn(train_loader)
def _mp_fn(index):
    train_routine()
    sys.exit(21)
def main(args):
    torch_xla.launch(_mp_fn, args=())
if __name__ == '__main__':
    import sys
    main(sys.argv[1:])
In 2.7 (see commented line above) - we are unable to pin memory/have persistent workers.
Historically we have used 1 worker, pinned memory, and persistent workers, as that allowed training to run uninterrupted. Even in torch 2.6, if we use more than 1 worker, training stops at a point that isn't deterministic; the dataloader is the bottleneck when using 1 worker.
I'm not entirely sure how to see what is going on here - any debugging advice would be greatly appreciated.
EDIT: The issue with the dataloader stopping occurs when num_workers >= 1; with 0 it does not occur.
@junjieqian
@tengyifei I have a minimal example that brings in a bit more of our process - however we are finding that assume_pure is quite a bit slower than when we first started to experiment. What would be the best method for us to test this? I would like to verify that it's a change in packages and not our process as we finalize this.
Hi @ttdd11, thanks for the update! I'll take this task now and am happy to support.
quite a bit slower
Would you mind sharing more? I assume you are using different package versions; if so, which versions? Also, it would be really appreciated if you could share the minimal example with the setup needed to run it. I will try to repro locally and investigate further.