
Fix race condition in `_safe_divide` by creating tensor directly on device

Open • Copilot opened this issue 5 months ago • 1 comment

What does this PR do?

Fixes a race condition in _safe_divide that could lead to uninitialized values when using non-blocking tensor transfers, particularly affecting MPS devices.

Closes #3095

The Problem

The previous implementation created a tensor on CPU and then transferred it to the target device:

zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(
    num.device, non_blocking=num.device.type != "mps"
)

This caused a race condition when non_blocking=True:

  1. The .to() call returns immediately without waiting for the memory copy to complete
  2. The tensor is used in torch.where() before the copy finishes
  3. This results in uninitialized or incorrect values being read

On an MPS device, the issue reporter sometimes got the correct default (0.0) and sometimes got uninitialized numbers.

The Solution

Create the tensor directly on the target device:

zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
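The sketch below shows how the patched line might sit inside the complete helper. Only the `zero_division_tensor` line is taken from this PR. The function name `safe_divide_sketch` is hypothetical, and the float-promotion step is an assumption. It illustrates the idea and is not a copy of the torchmetrics source.

```python
import torch
from torch import Tensor


def safe_divide_sketch(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor:
    """Hypothetical stand-in for `_safe_divide`: returns num / denom, with
    `zero_division` substituted wherever denom == 0."""
    # Assumption: integer inputs are promoted to float so the division is well defined.
    num = num if num.is_floating_point() else num.float()
    denom = denom if denom.is_floating_point() else denom.float()
    # The fix from this PR: build the fill value directly on num's device,
    # so there is no separate (possibly non-blocking) .to() transfer.
    zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
    # Where denom is zero, num / denom is inf or nan; torch.where replaces those entries.
    return torch.where(denom != 0, num / denom, zero_division_tensor)


num = torch.tensor([1.0, 2.0, 0.0])
denom = torch.tensor([0.0, 4.0, 0.0])
print(safe_divide_sketch(num, denom))                     # values: [0.0, 0.5, 0.0]
print(safe_divide_sketch(num, denom, zero_division=1.0))  # values: [1.0, 0.5, 1.0]
```

Because the 0-dim `zero_division_tensor` broadcasts inside `torch.where`, a single scalar tensor is enough for inputs of any shape.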

This eliminates the race condition by:

  • Avoiding any CPU-to-device transfers
  • Creating the tensor directly where it's needed
  • Removing device-specific conditional logic

Benefits

  • Fixes the race condition - No more uninitialized values
  • Simpler code - Removed MPS-specific handling
  • Works uniformly - MPS, CUDA, and CPU all handled the same way
  • No sync issues - torch.tensor(..., device=device) doesn't cause CUDA synchronization
  • Better performance - Avoids unnecessary CPU-GPU memory transfers

Testing

Added a test in tests/unittests/utilities/test_utilities.py that verifies:

  • Basic functionality with division by zero
  • Custom zero_division values
  • Device consistency across CPU and CUDA

All existing tests pass, including:

  • 28 utility tests
  • 144 classification tests using _safe_divide
  • Manual verification with metrics using this function (Dice, ExactMatch, etc.)

Related

This is similar to the approach in #3094, which was initially closed because of concerns about CUDA synchronization. However, creating tensors directly on the device with torch.tensor(..., device=device) does not cause synchronization, unlike using .to(device). PR #3101 attempted to fix this by disabling non_blocking for MPS, but the race condition could still occur. This PR fixes the root cause.

Original prompt

This section details the original issue you should resolve

<issue_title>Non-blocking call in _safe_divide leads to race condition</issue_title> <issue_description>## 🐛 Bug

There's an apparent race condition here:

https://github.com/Lightning-AI/torchmetrics/blob/e79093558ae48f52cb5882685b07545649171230/src/torchmetrics/utilities/compute.py#L77-L78

When moving the tensor to the target device (MPS in my case), I sometimes get the correct default (0.0) and sometimes get uninitialized numbers, which breaks everything.

Environment
  • TorchMetrics version (if built from source, add commit SHA): 1.7.1
  • Python & PyTorch Version (e.g., 1.0): Python 3.12.10, PyTorch 2.7.0
  • Any other relevant information such as OS (e.g., Linux): macOS, Darwin Kernel Version 24.4.0

<agent_instructions>Continue the work from https://github.com/Lightning-AI/torchmetrics/pull/3094</agent_instructions>

Comments on the Issue (you are @copilot in this section)

@Borda Hello, could you please be more specific about "sometimes uninitialized numbers", ideally with a way to reproduce it...

@SkafteNicki Probably related to Lightning-AI/torchmetrics#2955 and Lightning-AI/torchmetrics#1727. @Borda can we (read: me) get access to an MPS-enabled device for debugging?

@Borda
> get access to an MPS-enabled device for debugging?

Let's talk about it in DM.</comment_new>

Fixes Lightning-AI/torchmetrics#3095



📚 Documentation preview 📚: https://torchmetrics--3284.org.readthedocs.build/en/3284/

Copilot • Oct 04 '25 20:10

Codecov Report

❌ Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 37%. Comparing base (88bca94) to head (8856d20).

❌ Your project check has failed because the head coverage (37%) is below the target coverage (95%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #3284    +/-   ##
=======================================
- Coverage      37%     37%    -1%     
=======================================
  Files         364     349    -15     
  Lines       20096   19901   -195     
=======================================
- Hits         7520    7326   -194     
+ Misses      12576   12575     -1     

codecov[bot] • Nov 11 '25 11:11