bad_alloc on program exit
🐛 Bug
I ran into an issue where on program exit, torch-xla throws a bad_alloc exception.
To Reproduce
Run the following example:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn, einsum
import torch.nn.functional as F
import numpy as np


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Sequential(
            *[nn.Linear(120, 120) for _ in range(20)]
        )

    def forward(self, x):
        return self.linear(x)


if __name__ == "__main__":
    device = xm.xla_device()
    input = torch.ones([1, 120]).to(device)
    model = Model().to(device)
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters: %.3fM' % parameters)
    out = model(input)
    print("Shape of out :", out.shape)
Expected behavior
The program should exit gracefully, with no bad_alloc exception. Note that the example above does not throw bad_alloc on every run: the failure is caused by a race condition, so it shows up intermittently.
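Because the failure is intermittent, a single run may not show it. A minimal sketch of a helper that runs a command repeatedly and counts abnormal exits follows. The helper name count_failures is hypothetical, and the sketch assumes the bad_alloc makes the process exit with a non-zero status.

```python
import subprocess
import sys


def count_failures(cmd, runs=50):
    """Run `cmd` `runs` times and return how many runs exited with a non-zero status."""
    failures = 0
    for _ in range(runs):
        result = subprocess.run(
            cmd,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        # A process killed by a signal (e.g. abort) reports a negative
        # return code on POSIX, which is also counted here.
        if result.returncode != 0:
            failures += 1
    return failures
```

If the example above is saved as repro.py (a hypothetical file name), count_failures([sys.executable, "repro.py"]) gives a rough failure count over 50 runs.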
Environment
- Reproducible on XLA backend [CPU/TPU]:
- torch_xla version: 1.11
Additional context
After some digging: the bad_alloc happens because, on program exit, there is an Activate call here: https://github.com/pytorch/xla/blob/master/third_party/xla_client/xrt_computation_client.cc#L1695. That Activate call internally causes ReleaseHandles to be called, which eventually calls GetCacheKey(). The static variable here returns a garbage value that is very long, and since GetCacheKey does a StrCat, the result is a bad_alloc.
The likely reason is that the Wait after Activate doesn't guarantee that Runner() has completely executed. Since this is the last thing before program exit, I think what is missing is a triggered_task->Stop() call, which joins all the threads.
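To illustrate why joining the worker thread matters at exit, here is a toy Python model of an activate/stop worker. It is not the xla_client implementation: the class TriggeredTaskSketch, its methods, and demo() are hypothetical names, and the real C++ class may behave differently. The sketch only shows that once stop() has returned, the worker thread has exited, so a callback that had already started is guaranteed to have finished.

```python
import threading
import time


class TriggeredTaskSketch:
    """Toy model (hypothetical): a worker thread runs `fn` each time activate() is called."""

    def __init__(self, fn):
        self._fn = fn
        self._cv = threading.Condition()
        self._activated = False
        self._stopped = False
        self._thread = threading.Thread(target=self._runner)
        self._thread.start()

    def _runner(self):
        while True:
            with self._cv:
                # Sleep until someone activates or stops the task.
                self._cv.wait_for(lambda: self._activated or self._stopped)
                if self._stopped:
                    return
                self._activated = False
            # The callback runs outside the lock.
            self._fn()

    def activate(self):
        with self._cv:
            self._activated = True
            self._cv.notify_all()

    def stop(self):
        with self._cv:
            self._stopped = True
            self._cv.notify_all()
        # Joining ensures the worker is no longer running the callback.
        self._thread.join()

    def is_alive(self):
        return self._thread.is_alive()


def demo():
    started = threading.Event()
    finished = threading.Event()

    def slow_release():
        started.set()
        time.sleep(0.05)  # stands in for slow cleanup work
        finished.set()

    task = TriggeredTaskSketch(slow_release)
    task.activate()
    started.wait()  # the callback is now in progress
    task.stop()     # returns only after the worker thread has exited
    return finished.is_set() and not task.is_alive()
```

In this model, skipping stop() would let the main thread continue while slow_release is still in flight, which is the kind of overlap suspected above.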
Thanks Rahul! Can you elaborate a bit on "the Wait after Activate doesn't guarantee that Runner() has completely executed"?
What I meant is: when Activate() is called, it causes the Runner() method to be called, which in turn calls HandleReleases.
However, I think WaitForRun() may not be successfully waiting for the function inside Runner() to complete.
Makes sense, adding triggered_task->Stop() sgtm.
I can create the PR for this. Thanks for checking in!