
PJRT_Client_Destroy is not invoked in PT/XLA 2.8.0+

Open rajkthakur opened this issue 3 months ago • 11 comments

🐛 Bug

We are observing a memory leak because PJRT_Client_Destroy is not invoked in PT/XLA 2.8.0+, so the backend runtime is not cleaned up properly for AWS Neuron devices. The issue is also reproducible with the CPU device. Investigation reveals that this is related to issue #9384, where the ComputationClient is created as a raw pointer rather than a unique_ptr in runtime.cpp. This behavior persists in the current development branch of PyTorch/XLA.

We verified that converting the raw pointer to a smart pointer fixes the issue, likely because the destructor is then properly called.
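For illustration only, here is a minimal sketch of the pattern involved (the names are simplified stand-ins for this report, not the actual contents of runtime.cpp):

#include <memory>

// Stand-in for torch_xla's ComputationClient: in the real code, its destructor
// tears down the PjRtClient, which is what ends up invoking PJRT_Client_Destroy.
class ComputationClient {
 public:
  virtual ~ComputationClient() = default;
};

// Leaky pattern (similar in spirit to 2.8.0+): the raw pointer is never
// deleted, so ~ComputationClient() never runs and the plugin never sees
// PJRT_Client_Destroy.
ComputationClient* GetClientLeaky() {
  static ComputationClient* client = new ComputationClient();
  return client;
}

// Owning pattern: the function-local static unique_ptr is destroyed during
// static destruction at process exit, which runs ~ComputationClient().
ComputationClient* GetClientOwned() {
  static std::unique_ptr<ComputationClient> client =
      std::make_unique<ComputationClient>();
  return client.get();
}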

To Reproduce

Steps to reproduce the behavior:

  1. Create Python 3.10+ virtual env
  2. pip install torch==2.8.0 torch-xla==2.8.0
  3. Run export TF_CPP_MIN_LOG_LEVEL=0; export TF_CPP_VMODULE="cpu_client=1"; export NEURON_RT_LOG_LEVEL=DEBUG; export PJRT_DEVICE=CPU
  4. Run python -c "import torch_xla; device=torch_xla.device()"
  5. You will see logs like the ones below, where "PjRtCpuClient created." is logged but "PjRtCpuClient destroyed." is not:
(pt_28) ubuntu@ip-redacted:~$ python -c "import torch_xla; device=torch_xla.device()"
/home/ubuntu/pt_28/lib/python3.10/site-packages/torch/cuda/__init__.py:182: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)
  return torch._C._cuda_getDeviceCount() > 0
2025-10-03 04:17:42.142612: I external/xla/xla/pjrt/cpu/cpu_client.cc:334] PjRtCpuClient created.
  6. If you run the same steps (1-5) with PT/XLA 2.7.0, you will observe that the client is properly destroyed:
(pt_27) ubuntu@ip-redacted:~$ python -c "import torch_xla; device=torch_xla.device()"
/home/ubuntu/pt_27/lib/python3.10/site-packages/torch/cuda/__init__.py:174: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)
  return torch._C._cuda_getDeviceCount() > 0
2025-10-03 03:47:17.209860: I external/xla/xla/pjrt/cpu/cpu_client.cc:395] TfrtCpuClient created.
/home/ubuntu/pt_27/lib/python3.10/site-packages/torch/cuda/__init__.py:789: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
/home/ubuntu/pt_27/lib/python3.10/site-packages/torch/cuda/__init__.py:991: UserWarning: CUDA initialization: The NVIDIA driver on your system is too old (found version 11040). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:109.)
  r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
2025-10-03 03:47:19.926677: I external/xla/xla/pjrt/cpu/cpu_client.cc:398] TfrtCpuClient destroyed.
(pt_27)

Expected behavior

It is expected that PJRT_Client_Destroy is invoked on program termination, logging "PjRtCpuClient destroyed."

Environment

  • Reproducible on XLA backend [CPU/TPU]: CPU, Neuron
  • torch_xla version: 2.8, 2.8.1, Top of Tree

Additional context

rajkthakur • Oct 03 '25

Are frameworks supposed to call PJRT_Client_Destroy on program termination?

While I believe we should be calling it before terminating the program, I didn't really find any documentation regarding how frameworks should be managing PJRT plugins.

I think it would be insightful to understand why this is a problem with Neuron, i.e. whether it's possible to clean up Neuron resources on program termination without needing to call PJRT_Client_Destroy.

ysiraichi • Oct 13 '25

After looking around, I think we should be calling PJRT_Client_Destroy at some point (there are now two issues on this: #9675 and #9679). Most of the discussion happened in this thread.

In summary: I think that the shutdown idea should work.

Zhanyong correctly pointed out that:

I think the best course of action is to fix the hang, as implementing Shutdown correctly adds significant complexity to the design.

That said, here's how Shutdown should work if done correctly: it should allow in-flight computation that needs the client to finish, and it should let new computation (if any) that wants to use the client fail to get the client. This means we'll likely need to use a shared_ptr to hold the client (so that in-flight computation can extend its lifespan).

As you can see, this is doable but not trivial. Hence my advice to avoid it.
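A minimal sketch of that shared_ptr approach, using hypothetical AcquireClient()/ShutdownClient() names rather than existing torch_xla APIs: in-flight work holds a strong reference, new work fails once shutdown has happened, and PJRT_Client_Destroy runs when the last holder releases its reference.

#include <memory>
#include <mutex>
#include <stdexcept>

// Stand-in for torch_xla's ComputationClient (its destructor triggers
// PJRT_Client_Destroy in the real code).
class ComputationClient {
 public:
  virtual ~ComputationClient() = default;
};

namespace {
std::mutex g_client_mutex;
std::shared_ptr<ComputationClient> g_client =
    std::make_shared<ComputationClient>();
}  // namespace

// In-flight work copies the shared_ptr, so the client stays alive until the
// last holder releases it, even if ShutdownClient() runs in the meantime.
std::shared_ptr<ComputationClient> AcquireClient() {
  std::lock_guard<std::mutex> lock(g_client_mutex);
  if (!g_client) {
    throw std::runtime_error("PJRT client has already been shut down");
  }
  return g_client;
}

// Drops the global reference; PJRT_Client_Destroy runs once the last
// outstanding shared_ptr copy goes away.
void ShutdownClient() {
  std::lock_guard<std::mutex> lock(g_client_mutex);
  g_client.reset();
}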

Although I do agree that the best course of action is indeed to fix the hanging problem in Neuron, as I have already mentioned in this issue, I believe we should be calling PJRT_Client_Destroy (at some point) to let the PjRt plugin know that this client is no longer active. The reason is that the PjRt plugin might interact with an external long-running process that allocates resources for the client and waits for the plugin's PJRT_Client_Destroy call as the signal to dispose of the allocated resources.

Claim: calling a new runtime::ShutdownComputationClient() at the end of the if statement below is enough to satisfy the requirements Zhanyong described.

https://github.com/pytorch/xla/blob/d291621f583574f575888da33eaabe866056592c/torch_xla/csrc/init_python_bindings.cpp#L270-L278
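For concreteness, a rough sketch of what such a helper could look like; the ownership scheme below is an assumption for this discussion (the current runtime.cpp holds a raw pointer), not the actual code:

#include <memory>

namespace runtime {

// Stand-in for torch_xla's ComputationClient; its destructor is what ends up
// calling PJRT_Client_Destroy.
class ComputationClient {
 public:
  virtual ~ComputationClient() = default;
};

// Hypothetical owning slot for the process-wide client.
std::unique_ptr<ComputationClient>& ClientSlot() {
  static auto* slot = new std::unique_ptr<ComputationClient>();
  return *slot;
}

// Proposed helper: called at the end of the exit path, after WaitDeviceOps()
// has drained every scheduled thread and pending device operation, so nothing
// can still be using the client.
void ShutdownComputationClient() {
  ClientSlot().reset();  // runs ~ComputationClient(), hence PJRT_Client_Destroy
}

}  // namespace runtime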

Explanation: the XLAGraphExecutor::WaitDeviceOps() call below does the work needed so that we can safely destroy the PjRt client, by:

  • Calling LockDevices(), which waits until all in-flight threads that were scheduled by thread::Schedule(...) are done
  • Calling ComputationClient::WaitDeviceOps(), which makes sure all pending computation executions by OpenXLA are done

https://github.com/pytorch/xla/blob/d291621f583574f575888da33eaabe866056592c/torch_xla/csrc/init_python_bindings.cpp#L263-L268

Assuming that the PrepareToExit() function is called during the Python interpreter shutdown process, I think we won't ever try to get a reference to the client again, since WaitDeviceOps() guarantees that all threads have already finished.

Finally, I have looked for other code sites where we create new threads, and found no other non-joined threads that make use of the PjRt client.

@zhanyong-wan Let me know what you think.

ysiraichi • Oct 20 '25

I'm at the PyTorch conference and don't have time to think through this deeply. That said, my concern with this approach is that it's delicate: even if it works today, it can be hard to guarantee that future changes won't accidentally introduce a deadlock or use-after-free. In contrast, the design where we simply let the OS reclaim the resources is much simpler and more robust. Therefore I'm still in favor of avoiding the need to shut down.

zhanyong-wan • Oct 22 '25

In contrast, the design where we simply let the OS reclaim the resources is much simpler and more robust.

I do agree with this.

However, I believe the key question here is: should PJRT_Client_Destroy be called whenever a client is no longer being used?

In other words, is this part of the PjRt plugin lifecycle?

  1. If not, then I agree we should leave the OS to reclaim the resources.
  2. Otherwise, it means that not calling PJRT_Client_Destroy is incorrect behavior.

I tend to believe that (2) is the case because of the recently opened issues #9675 and #9679, whose wording also makes it sound like that. Finally, here is a possible use case:

The reason is that the PjRt plugin might interact with an external long-running process that allocates resources for the client and waits for the plugin's PJRT_Client_Destroy call as the signal to dispose of the allocated resources.

ysiraichi • Oct 23 '25

I do think that this can wait until the end of the PyTorch conference. So, no need to make a rushed decision.

@jeffhataws @rajkthakur @saarthak-aws @bhavya01 @qihqi @tengyifei @vanbasten23

It would be nice to have some feedback on this from PjRt plugin developers. Is calling PJRT_Client_Destroy necessary for maintaining a valid plugin state/behavior? What exactly runs on a PJRT_Client_Destroy call? Is there anything like an interaction with a long-running process for resource cleanup?

ysiraichi • Oct 23 '25

Re: Should PJRT_Client_Destroy be called whenever a client is no longer being used?

I think this question is relevant only if we plan to create another pjrt client later, in which case it's important to clean up the existing one first. If, however, we never plan to create another pjrt client (which is the case when we are exiting), it shouldn't matter if we explicitly destroy the client, as the OS will take care of it.

Re: The reason is that the PjRt plugin might interact with an external long-running process that allocates resources for the client and waits for the plugin's PJRT_Client_Destroy call as the signal to dispose of the allocated resources.

I'm not aware of such a process, but if it does exist, that process should rely on the OS to properly return resources instead of demanding an explicit clean-up call from its client. The reason is that a process might crash, and thus there's no guarantee the explicit clean-up call will ever be executed.

zhanyong-wan • Oct 24 '25

That makes sense.

Since I also have no knowledge of how TPU PjRt works under the hood, I think we should gather more input. I will try to look into the deprecated CUDA PjRt client. In the meantime, it would be nice to have some feedback from the developers of the Neuron and TPU PjRt clients.

If we don't get any new information, I think we should leave it as is for now. It's safer and simpler.

ysiraichi • Oct 24 '25

Hi @zhanyong-wan @ysiraichi,

As stated earlier, we've implemented proper cleanup mechanisms in our PJRT client that depend on the PJRT_Client_Destroy lifecycle call. The current implementation requires an explicit runtime shutdown to properly release resources, including device memory allocations.

This has been working for us with both JAX and PyTorch/XLA; I believe JAX still makes the explicit PJRT_Client_Destroy call.

I think we should continue with the explicit PJRT_Client_Destroy invocation if there is no evidence of it having caused race conditions.

rajkthakur • Oct 25 '25

Hi @rajkthakur, could you investigate why the explicit shutdown is needed for releasing the resources? As mentioned earlier, the OS should release everything when the process terminates.

Re: we should continue with the explicit PJRT_Client_Destroy invocation if there is no evidence of it having caused race conditions

The concern is less about whether there's a race condition today; it's more about complexity and robustness. Even if there's no race today, destroying the client at exit time is incompatible with the Google C++ style guide's design philosophy and makes the code more delicate and error-prone. Therefore we should try hard to avoid this pattern.

zhanyong-wan • Oct 27 '25

I looked at the CUDA PjRt client, and it seems that it does nothing besides actually deleting the PjRtClient instance. I guess the CUDA driver takes care of cleaning everything up. Does anyone know if there is a PjRt spec available?

@rajkthakur Is it possible for the Neuron driver, instead of relying just on the PJRT_Client_Destroy call, to keep an eye on the PIDs of its users, so that when they terminate it can dispose of the allocated resources?
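For reference, the kind of liveness check meant here is plain POSIX; nothing Neuron-specific is assumed below:

#include <errno.h>
#include <signal.h>
#include <sys/types.h>

// Returns true while the process with the given PID still exists.
// kill(pid, 0) delivers no signal; it only checks for existence/permissions.
bool ClientProcessAlive(pid_t pid) {
  return kill(pid, 0) == 0 || errno == EPERM;
}

// A long-running resource owner (e.g. a driver daemon) could register each
// client's PID, poll ClientProcessAlive() periodically, and reclaim that
// client's allocations once the process is gone.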


@zhanyong-wan What if we created a thread-safe wrapper (e.g. using a shared mutex) for ComputationClient*?

#include <memory>
#include <mutex>         // std::unique_lock
#include <shared_mutex>  // std::shared_mutex, std::shared_lock
#include <stdexcept>
#include <utility>       // std::move

#include "torch_xla/csrc/runtime/computation_client.h"  // ComputationClient

struct Wrapper {
  Wrapper(std::unique_ptr<ComputationClient> client) : client_(std::move(client)) {}

  struct Accessor {
    Accessor(Wrapper& w) : w_(w), lock_(w.m_) {
      if (!w_.valid_) {
        // Throw error when trying to access an already invalidated PJRT client.
        throw std::runtime_error("trying to get invalid PJRT client");
      }
    }

    ComputationClient* operator->() {
      return w_.client_.get();
    }

    private:
      Wrapper& w_;

      // Whenever this `Accessor` instance is being used, hold the
      // shared lock, so that it's not possible to invalidate it.
      std::shared_lock<std::shared_mutex> lock_;
  };

  // Enables users to call `client->foo()` directly!
  // No need to fix every call-site.
  Accessor operator->() {
    return Accessor(*this);
  }

  // Deletes the `PjRtClient` and invalidates this wrapper.
  void Invalidate() {
    std::unique_lock<std::shared_mutex> lock(m_);
    valid_ = false;
    // Calls PJRT_Client_Destroy
    client_.reset();
  }

  private:
    std::shared_mutex m_;
    bool valid_ = true;
    std::unique_ptr<ComputationClient> client_;
};

#include "absl/status/statusor.h"

// Non-const so that operator->() and Invalidate() (both of which take the
// mutex) can be called on the returned wrapper.
absl::StatusOr<Wrapper>& GetComputationClientWrapper() {
  static auto& wrapper = *new absl::StatusOr<Wrapper>(...);
  return wrapper;
}

Then, we could still call ComputationClient methods like client->foo() safely (see the usage sketch right after this list):

  • In order to delete the PjRtClient, we need to acquire the mutex lock exclusively
    • No Accessor should be alive
  • In order to use the PjRtClient, we need to acquire the shared mutex lock
    • Invalidate() (the only exclusive lock) cannot be running
  • If a thread tries to access an invalidated PjRtClient, Accessor will throw an error
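Call sites would then look roughly like this; SomeMethod() is just a stand-in for any existing ComputationClient method:

// Regular call site: the temporary Accessor returned by operator->() holds
// the shared lock only for the duration of this statement.
GetComputationClientWrapper().value()->SomeMethod();

// At exit (e.g. from PrepareToExit(), after WaitDeviceOps()): destroys the
// PjRtClient (invoking PJRT_Client_Destroy) and makes any later access throw.
GetComputationClientWrapper().value().Invalidate();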

The only problem would be if:

Wrapper& w = UNWRAP(GetComputationClientWrapper());
// Obvious misuse.
ComputationClient* cc = w.operator->().operator->();
// - `cc` outlives the temporary `Accessor`, which has already been destroyed.
// - We might end up using an invalid `cc`.

Of course, more complexity introduces more bugs in non-obvious places. But I believe this is worth doing because:

  1. It calls PJRT_Client_Destroy as expected
  2. Only the types of explicitly declared variables need to be changed at call sites

What do you think?

ysiraichi • Oct 27 '25

Sorry, @ysiraichi, let's wait for @rajkthakur to look into a fix on their end before introducing more complexity to torch_xla. Thanks!

zhanyong-wan • Oct 28 '25