bad_alloc on program exit

Open • aws-rhsoln opened this issue 3 years ago • 4 comments

🐛 Bug

I ran into an issue where torch-xla throws a bad_alloc exception on program exit.

To Reproduce

Run the following example:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch import nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # A stack of 20 small linear layers (120 -> 120 features each).
        self.linear = nn.Sequential(
            *[nn.Linear(120, 120) for _ in range(20)]
        )

    def forward(self, x):
        return self.linear(x)


if __name__ == "__main__":
    device = xm.xla_device()
    x = torch.ones([1, 120]).to(device)
    model = Model().to(device)

    # Count trainable parameters, in millions.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    num_params = sum(np.prod(p.size()) for p in trainable) / 1_000_000
    print('Trainable Parameters: %.3fM' % num_params)

    out = model(x)
    print("Shape of out :", out.shape)
    # When the bug triggers, bad_alloc is thrown after this point, during program exit.
```
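For reference, the first line the script prints can be checked by hand. By default each nn.Linear(in_features, out_features) holds an out_features × in_features weight matrix plus out_features biases, so the 20 layers in the model above add up as follows:

```python
# Each nn.Linear(in_features, out_features) has a weight matrix of shape
# (out_features, in_features) plus one bias per output feature (bias=True default).
in_features = out_features = 120
num_layers = 20

per_layer = out_features * in_features + out_features  # 14400 + 120 = 14520
total = num_layers * per_layer                          # 290400

print('Trainable Parameters: %.3fM' % (total / 1_000_000))  # Trainable Parameters: 0.290M
```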

Expected behavior

The program should exit gracefully, without a bad_alloc exception. Note that the example above does not throw on every run: because of a race condition, it fails only intermittently.

Environment

  • Reproducible on XLA backend [CPU/TPU]:
  • torch_xla version: 1.11

Additional context

After some digging: the bad_alloc happens because, on program exit, there is an Activate call (https://github.com/pytorch/xla/blob/master/third_party/xla_client/xrt_computation_client.cc#L1695). That Activate call internally results in ReleaseHandles being called, which eventually calls GetCacheKey(). The static variable read there returns a garbage value that is very long, and since GetCacheKey() builds its result with StrCat, the allocation fails with bad_alloc.

This might be happening because the Wait after Activate doesn't guarantee that Runner() has finished executing. Since this is the last thing that runs before program exit, I think a call to triggered_task->Stop(), which joins all the threads, is missing.
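To make the proposed fix concrete, here is a minimal Python sketch of a triggered background task, assuming semantics inferred from this thread: activate() wakes a worker thread that runs a cleanup function, wait_for_run() blocks until a given number of runs have finished, and stop() shuts the worker down and joins it. The class and method names are hypothetical stand-ins for the C++ triggered_task discussed here, not torch_xla's API. The important step is the last one: because stop() joins the thread, no run of the cleanup function can still be in progress once it returns.

```python
import threading


class TriggeredTask:
    """Hypothetical stand-in for a triggered background task (not torch_xla's API).

    activate() wakes the worker, wait_for_run() waits for *completed* runs,
    and stop() drains pending work and joins the worker thread.
    """

    def __init__(self, fn):
        self._fn = fn
        self._cv = threading.Condition()
        self._pending = 0     # activations not yet picked up by the worker
        self._completed = 0   # runs of fn that have fully returned
        self._running = True
        self._thread = threading.Thread(target=self._runner)
        self._thread.start()

    def activate(self):
        with self._cv:
            self._pending += 1
            self._cv.notify_all()

    def wait_for_run(self, count):
        # Block until at least `count` runs of fn have returned.
        with self._cv:
            self._cv.wait_for(lambda: self._completed >= count)

    def stop(self):
        # Ask the worker to exit, then join it: after this returns,
        # no call to fn can still be in progress.
        with self._cv:
            self._running = False
            self._cv.notify_all()
        self._thread.join()

    def _runner(self):
        while True:
            with self._cv:
                self._cv.wait_for(lambda: self._pending > 0 or not self._running)
                if self._pending == 0:   # woken only because stop() was called
                    return
                self._pending = 0
            self._fn()                   # run the cleanup outside the lock
            with self._cv:
                self._completed += 1
                self._cv.notify_all()


released = []
task = TriggeredTask(lambda: released.append("handles released"))
task.activate()        # e.g. the final release on the exit path
task.wait_for_run(1)
task.stop()            # joins the worker before the program exits
print(released, task._thread.is_alive())  # ['handles released'] False
```

Joining in stop() is what makes the shutdown ordering safe regardless of how precisely the wait is implemented: the caller cannot proceed to exit until the worker thread itself has returned.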

aws-rhsoln • Sep 22 '22 02:09

Thanks Rahul! Can you elaborate a bit on why the Wait after Activate doesn't guarantee that Runner() has completely executed?

JackCaoG • Sep 22 '22 22:09

What I meant is: calling Activate() results in the Runner() method being called, which in turn calls HandleReleases. However, I think WaitForRun() may not actually be waiting for the function inside Runner() to complete.
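The suspected failure mode can be modeled in a small, hypothetical Python example (not torch_xla code, and not necessarily how the real WaitForRun() is written): if the worker bumps the run counter when it picks up an activation rather than after the function returns, wait_for_run() can return while the function is still in flight. The demo holds the function on an Event so the outcome is deterministic.

```python
import threading


class RacyTriggeredTask:
    """Hypothetical task whose run counter is bumped when a run *starts*."""

    def __init__(self, fn):
        self._fn = fn
        self._cv = threading.Condition()
        self._pending = 0   # activations not yet picked up by the worker
        self._runs = 0      # counted as soon as a run begins (the bug)
        # Daemon thread: the interpreter will not wait for it at exit.
        self._thread = threading.Thread(target=self._runner, daemon=True)
        self._thread.start()

    def activate(self):
        with self._cv:
            self._pending += 1
            self._cv.notify_all()

    def wait_for_run(self, count):
        # Only waits until `count` runs have *started*.
        with self._cv:
            self._cv.wait_for(lambda: self._runs >= count)

    def _runner(self):
        while True:
            with self._cv:
                self._cv.wait_for(lambda: self._pending > 0)
                self._pending = 0
                self._runs += 1      # BUG: counted before fn has run
                self._cv.notify_all()
            self._fn()               # may still be executing after waiters wake up


gate = threading.Event()       # holds the cleanup "in flight" for the demo
finished = threading.Event()


def release_handles():
    gate.wait()                # stand-in for slow handle-release work
    finished.set()


task = RacyTriggeredTask(release_handles)
task.activate()
task.wait_for_run(1)
finished_early = finished.is_set()
print("cleanup finished before wait returned:", finished_early)  # False
gate.set()                     # let the in-flight run complete
finished.wait(timeout=5)
```

Here the wait returns while release_handles() is still blocked. In the exit path described in this issue, that window corresponds to the handle-release work possibly still running while the process is shutting down.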

aws-rhsoln • Sep 22 '22 22:09

Makes sense; adding triggered_task->Stop() sgtm.

JackCaoG • Sep 23 '22 18:09

I can create the PR for this. Thanks for checking in!

aws-rhsoln • Sep 23 '22 19:09