metric.compute() hangs when using DDP with multiple GPUs
Bug description
I'm using the default Accuracy metric (though it appears to be true of any metric), and calling metric.compute() hangs after the first epoch and never resolves (I ran it overnight and it never progressed). Based on some print() statements, the issue only occurs with metric computation after the training epoch ends, not after the validation epoch ends. The issue does not happen when using only 1 GPU or the CPU. It is also independent of dataset size: I tried with a dataset containing only the first 2 batches and got the same result. I see there's another relevant issue (#5930) from 3 years ago, but it has no solution (it just says to update the version and open a new issue).
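For context, as far as I understand, metric.compute() under DDP all-gathers the metric state across ranks (sync_on_compute defaults to True), so it blocks until every rank reaches the call. A minimal standalone sketch of that blocking behavior (hypothetical demo code, not from my training setup; it uses a gloo process group on CPU, and the 10-second sleep just stands in for a rank that never gets to compute()):

import os
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torchmetrics

def worker(rank: int, world_size: int) -> None:
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    metric = torchmetrics.Accuracy(task='binary')
    metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))

    if rank == 1:
        time.sleep(10)  # stand-in for a rank that is late to (or never reaches) compute()

    # compute() all-gathers the metric state across ranks, so rank 0 sits here
    # blocked until rank 1 also calls compute().
    print(f'[{rank}] acc = {metric.compute()}')

    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)

The hang I see looks like that same blocking: in the logs below, rank 1 enters compute() for train_acc and never returns.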
What version are you seeing the problem on?
v2.4
How to reproduce the bug
In Model(pl.LightningModule).__init__(self, splits, ...):  # more args omitted
# ...
# Without the "_metrics" suffix the keys conflict with nn.ModuleDict attributes;
# adding the split prefix as well, though I'm not sure that's needed.
self.split_metrics = nn.ModuleDict({
    f'{split}_metrics': nn.ModuleDict({
        f'{split}_{name.replace(".", "")}': metric
        for name, metric in {
            'acc': torchmetrics.Accuracy(task='binary'),
            # ... more metrics
        }.items()
    })
    for split in splits
})
# ...
(Note: I considered using a metric collection, but some of my metrics need different inputs and I couldn't figure out how to account for that)
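For illustration only, the kind of grouping I couldn't work out would presumably look something like this (a hypothetical sketch with placeholder metric names, one MetricCollection per input kind, cloned per split):

from torch import nn
import torchmetrics
from torchmetrics import MetricCollection

splits = ['train', 'val']  # placeholder

# Hypothetical grouping: one collection per kind of input the metrics need.
prob_metrics = MetricCollection({
    'auroc': torchmetrics.AUROC(task='binary'),    # wants probabilities
})
label_metrics = MetricCollection({
    'acc': torchmetrics.Accuracy(task='binary'),   # fine with hard labels
})

# One clone per split, with the split name as prefix.
split_prob_metrics = nn.ModuleDict(
    {split: prob_metrics.clone(prefix=f'{split}_') for split in splits})
split_label_metrics = nn.ModuleDict(
    {split: label_metrics.clone(prefix=f'{split}_') for split in splits})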
In Model(pl.LightningModule)._step(self, batch, batch_idx, *, split, **kwargs):
# ...
# Some metrics need probabilities, others hard labels (see note above).
for name, metric in self.split_metrics[f'{split}_metrics'].items():
    metric.update(y_pred_prob if needs_probability(name) else y_pred_label, y)
# ...
Relevant overloads in Model(pl.LightningModule):

def training_step(self, *args, **kwargs):
    return self._step(*args, **kwargs, split='train')

def validation_step(self, *args, **kwargs):
    return self._step(*args, **kwargs, split='val')

def on_train_epoch_end(self, *args, **kwargs):
    self._on_epoch_end(*args, **kwargs, split='train')

def on_validation_epoch_end(self, *args, **kwargs):
    self._on_epoch_end(*args, **kwargs, split='val')
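For completeness (since the hooks above call it), _on_epoch_end looks roughly like this; I've reconstructed the sketch from the print statements that show up in the logs below, so the logging details are simplified:

def _on_epoch_end(self, *args, split, **kwargs):
    rank = self.global_rank
    print(f'[{rank}] ENTERING _on_epoch_end :: {split}')
    log_dict = {}
    print(f'[{rank}] log dict initialized :: {split}')
    for name, metric in self.split_metrics[f'{split}_metrics'].items():
        print(f'[{rank}] now in loop for {name} :: {split}')
        log_dict[name] = metric.compute()  # the call that never returns for train_acc
        print(f'[{rank}] computed {name} :: {split}')
        self.log(name, log_dict[name])
        print(f'[{rank}] logged {name} :: {split}')
        metric.reset()
        print(f'[{rank}] reset {name} :: {split}')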
Error messages and logs
I added some print statements; they show the hang happens for train_acc. This is with 2 GPUs.
Sanity Checking DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 2.80it/s][0] VALIDATION EPOCH END
[0] ENTERING _on_epoch_end :: val
[0] log dict initialized :: val
[0] now in loop for val_acc :: val
[1] VALIDATION EPOCH END
[1] ENTERING _on_epoch_end :: val
[1] log dict initialized :: val
[1] now in loop for val_acc :: val
[1] computed val_acc :: val
[1] logged val_acc :: val
[1] reset val_acc :: val
[0] computed val_acc :: val
[0] logged val_acc :: val
[0] reset val_acc :: val
Epoch 0: 0%| | 0/2 [00:00<?, ?it/s][rank0]:[W823 21:27:11.992236756 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
[rank1]:[W823 21:27:11.034518963 reducer.cpp:1400] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00, 1.34it/s, v_num=26[0] VALIDATION EPOCH END: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 5.13it/s]
[1] VALIDATION EPOCH END
[0] ENTERING _on_epoch_end :: val
[1] ENTERING _on_epoch_end :: val
[0] log dict initialized :: val
[1] log dict initialized :: val
[0] now in loop for val_acc :: val
[1] now in loop for val_acc :: val
[1] computed val_acc :: val
[0] computed val_acc :: val
[1] logged val_acc :: val
[0] logged val_acc :: val
[1] reset val_acc :: val
[0] reset val_acc :: val
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00, 0.96it/s, v_num=26, loss/val=0.172][1] TRAINING EPOCH END
[1] ENTERING _on_epoch_end :: train
[1] log dict initialized :: train
[1] now in loop for train_acc :: train
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.4.0): 2.4.0
#- PyTorch Version (e.g., 2.4): 2.4.0
#- TorchMetrics Version: 1.4.1
#- Python version (e.g., 3.12): 3.11.9
#- OS (e.g., Linux): Linux
#- CUDA/cuDNN version: 12.1
#- GPU models and configuration: 8 x Tesla V100-SXM2-16GB
#- How you installed Lightning(`conda`, `pip`, source): pip
More info
No response
I'm also getting this problem.
With the latest version (main branch on GitHub) the problem doesn't exist. With the latest Docker release the issue is still present for me.
@AmruthPillai maybe publish 4.2.0?