torchmetrics
MeanAveragePrecision throws an error in DDP mode
🐛 Bug
Hi :)! I am training a detection model (Mask R-CNN), and it works fine with the MeanAveragePrecision metric in single-GPU mode. However, when I enable DDP with 4 GPUs, it throws an error during the first validation phase.
I have investigated, and when `compute` runs there are more detection labels than detections. However, I made sure this is never the case when calling `update` on the metric (I added an assert).
```
  File "/home/andrea/projects/my-project/src/callbacks/detection_metrics.py", line 77, in on_validation_epoch_end
    mean_ap = self.val_mean_ap.compute()
  File "/home/andrea/projects/my-project/.venv/lib/python3.8/site-packages/torchmetrics/metric.py", line 523, in wrapped_func
    value = compute(*args, **kwargs)
  File "/home/andrea/projects/my-project/.venv/lib/python3.8/site-packages/torchmetrics/detection/mean_ap.py", line 908, in compute
    precisions, recalls = self._calculate(classes)
  File "/home/andrea/projects/my-project/.venv/lib/python3.8/site-packages/torchmetrics/detection/mean_ap.py", line 728, in _calculate
    ious = {
  File "/home/andrea/projects/my-project/.venv/lib/python3.8/site-packages/torchmetrics/detection/mean_ap.py", line 729, in <dictcomp>
    (idx, class_id): self._compute_iou(idx, class_id, max_detections)
  File "/home/andrea/projects/my-project/.venv/lib/python3.8/site-packages/torchmetrics/detection/mean_ap.py", line 468, in _compute_iou
    det = [det[i] for i in det_label_mask]
  File "/home/andrea/projects/my-project/.venv/lib/python3.8/site-packages/torchmetrics/detection/mean_ap.py", line 468, in <listcomp>
    det = [det[i] for i in det_label_mask]
IndexError: tuple index out of range
```
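To make the failure mode concrete, here is a hypothetical pure-Python reduction of the failing line (`det = [det[i] for i in det_label_mask]`). It does not reproduce TorchMetrics' real data structures: it simply assumes, for illustration, that the stored detections for one image are a tuple and that the mask is the list of indices whose label equals the class. With more labels than detections (the situation described above), a label-derived index points past the end of the tuple.

```python
def select_class_detections(detections, detection_labels, class_id):
    """Hypothetical reduction of the failing comprehension in _compute_iou.

    `detections` is a tuple of per-box entries for one image and
    `detection_labels` holds the matching per-box class ids (assumed layout).
    """
    # Indices of boxes predicted as `class_id`, derived from the labels only.
    det_label_mask = [i for i, label in enumerate(detection_labels) if label == class_id]
    # Indexing the detections with label-derived indices fails as soon as
    # there are more labels than detections.
    return [detections[i] for i in det_label_mask]


# Consistent lengths: works.
print(select_class_detections(("box_a", "box_b"), (1, 2), class_id=2))  # ['box_b']

# Three labels but only two detections: index 2 is out of range.
try:
    select_class_detections(("box_a", "box_b"), (1, 2, 2), class_id=2)
except IndexError as err:
    print(f"IndexError: {err}")  # IndexError: tuple index out of range
```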
Additionally, I stepped through each call to `update` and can confirm that the number of detections in each item matches the lengths of `item['labels']` and `item['scores']`.
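The per-update check described above could look roughly like the helper below, called on `preds` just before `update`. This is a minimal sketch, not the reporter's actual assert: the name `check_detection_lengths` is hypothetical, and it assumes the per-image dict format MeanAveragePrecision takes for predictions (`boxes`, `scores`, `labels`, plus `masks` if present). It only uses `len()`, so it works on torch tensors and plain lists alike.

```python
from typing import Any, Mapping, Sequence


def check_detection_lengths(preds: Sequence[Mapping[str, Any]]) -> None:
    """Raise ValueError if any image has mismatched per-detection counts.

    Hypothetical helper: `preds` is assumed to be the list of per-image
    dicts passed to MeanAveragePrecision.update.
    """
    for image_idx, item in enumerate(preds):
        # Number of predicted boxes for this image.
        n_boxes = len(item["boxes"])
        for key in ("scores", "labels"):
            if len(item[key]) != n_boxes:
                raise ValueError(
                    f"image {image_idx}: {len(item[key])} {key} but {n_boxes} boxes"
                )
        # Mask R-CNN outputs may also carry one mask per detection.
        if "masks" in item and len(item["masks"]) != n_boxes:
            raise ValueError(
                f"image {image_idx}: {len(item['masks'])} masks but {n_boxes} boxes"
            )


preds = [{"boxes": [[0, 0, 10, 10]], "scores": [0.9], "labels": [1]}]
check_detection_lengths(preds)  # passes silently
```

In the reported case a check like this passes on every update, which suggests the mismatch only shows up in the metric state that `compute` sees after synchronization across the 4 processes, not in the inputs passed to `update`.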
To Reproduce
Steps to reproduce the behavior...
Code sample
Expected behavior
Environment
- TorchMetrics version (and how you installed TM, e.g. `conda`, `pip`, build from source): 0.9.2 with pip
- Python & PyTorch Version (e.g., 1.0): Python 3.8.10, PyTorch 1.12
- Any other relevant information such as OS (e.g., Linux): Linux