Fix race condition in `_safe_divide` by creating tensor directly on device
What does this PR do?
Fixes a race condition in _safe_divide that could lead to uninitialized values when using non-blocking tensor transfers, particularly affecting MPS devices.
Closes #3095
The Problem
The previous implementation created a tensor on CPU and then transferred it to the target device:
zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(
num.device, non_blocking=num.device.type != "mps"
)
This caused a race condition when non_blocking=True:
- The
.to()call returns immediately without waiting for the memory copy to complete - The tensor is used in
torch.where()before the copy finishes - This results in uninitialized or incorrect values being read
Issue reporter experienced "sometimes correct default (0.0) but sometimes uninitialized numbers" on MPS devices.
The Solution
Create the tensor directly on the target device:
zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
This eliminates the race condition by:
- Avoiding any CPU-to-device transfers
- Creating the tensor directly where it's needed
- Removing device-specific conditional logic
Benefits
- ✅ Fixes the race condition - No more uninitialized values
- ✅ Simpler code - Removed MPS-specific handling
- ✅ Works uniformly - MPS, CUDA, and CPU all handled the same way
- ✅ No sync issues -
torch.tensor(..., device=device)doesn't cause CUDA synchronization - ✅ Better performance - Avoids unnecessary CPU-GPU memory transfers
Testing
Added comprehensive test in tests/unittests/utilities/test_utilities.py that verifies:
- Basic functionality with division by zero
- Custom
zero_divisionvalues - Device consistency across CPU and CUDA
All existing tests pass including:
- 28 utility tests
- 144 classification tests using
_safe_divide - Manual verification with metrics using this function (Dice, ExactMatch, etc.)
Related
This is similar to the approach in #3094 which was initially closed due to concerns about CUDA synchronization. However, creating tensors directly on device with torch.tensor(..., device=device) does not cause synchronization, unlike using .to(device). PR #3101 attempted to fix this by disabling non_blocking for MPS, but the race condition could still occur. This PR properly fixes the root cause.
Original prompt
This section details on the original issue you should resolve
<issue_title>Non-blocking call in
_safe_divideleads to race condition</issue_title> <issue_description>## 🐛 BugThere's an apparent race condition here:
https://github.com/Lightning-AI/torchmetrics/blob/e79093558ae48f52cb5882685b07545649171230/src/torchmetrics/utilities/compute.py#L77-L78
When moving the tensor to the target device (MPS in my case), I get sometimes the correct default (
0.0) but sometimes uninitialized numbers, screwing up everything.Environment
- TorchMetrics version (if build from source, add commit SHA):
1.7.1- Python & PyTorch Version (e.g., 1.0): Python 3.12.10, PyTorch 2.7.0
- Any other relevant information such as OS (e.g., Linux): macOS, Darwin Kernel Version 24.4.0
<agent_instructions>Continue in work from https://github.com/Lightning-AI/torchmetrics/pull/3094</agent_instructions>
Comments on the Issue (you are @copilot in this section)
@Borda Hello, could you please be more specific on "sometimes uninitialized numbers", ideally if we could reproduce it...@SkafteNicki Probably related to Lightning-AI/torchmetrics#2955 and Lightning-AI/torchmetrics#1727. @Borda can we (read: me) get access to a MPS enabled device for debugging?@Borda > get access to a MPS enabled device for debugging?let talk about it in DM</comment_new>
✨ Let Copilot coding agent set things up for you — coding agent works faster and does higher quality work when set up for your repo.
📚 Documentation preview 📚: https://torchmetrics--3284.org.readthedocs.build/en/3284/
Codecov Report
:x: Patch coverage is 0% with 1 line in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 37%. Comparing base (88bca94) to head (8856d20).
:x: Your project check has failed because the head coverage (37%) is below the target coverage (95%). You can increase the head coverage or adjust the target coverage.
Additional details and impacted files
@@ Coverage Diff @@
## master #3284 +/- ##
=======================================
- Coverage 37% 37% -1%
=======================================
Files 364 349 -15
Lines 20096 19901 -195
=======================================
- Hits 7520 7326 -194
+ Misses 12576 12575 -1
:rocket: New features to boost your workflow:
- :snowflake: Test Analytics: Detect flaky tests, report on failures, and find test suite problems.