torchmetrics

RuntimeError when using MAP-metric

Open V-Soboleva opened this issue 2 years ago • 17 comments

🐛 Bug

Hi! I am training a detection model and use the MAP metric during validation. I got the following error in validation_epoch_end, when calling self.metric.compute(): RuntimeError: expected scalar type Float but found Bool.

To Reproduce

Pick a Faster R-CNN model (I used fasterrcnn_resnet50_fpn_v2() from torchvision). Implement validation_step, which calls self.metric.update(...) on the model outputs and targets, and validation_epoch_end, which calls self.metric.compute() on the accumulated results.

Code sample

import pytorch_lightning as pl
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchmetrics.detection.mean_ap import MeanAveragePrecision


class FasterRCNNModel(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()

        # Replace the box predictor head so the model outputs `num_classes` classes
        model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn_v2()
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        self.model = model
        self.metric = MeanAveragePrecision(box_format='xyxy', iou_type='bbox')

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        # Lightning puts the model in eval mode during validation, so it returns predictions
        preds = self.model(images)
        self.metric.update(preds, targets)

    def validation_epoch_end(self, outs):
        # compute() returns a dict of metrics; log the overall mAP value
        mAP = self.metric.compute()
        self.log("val/mAP", mAP["map"])
        self.metric.reset()

targets (List[Dict]), containing:

  • boxes (torch.float32)
  • labels (torch.int64)

preds (List[Dict]), containing:

  • boxes (torch.float32)
  • scores (torch.float32)
  • labels (torch.int64)
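If you want to rule out the inputs, a quick check against the dtypes listed above can help. The sketch below is a hypothetical helper (`check_detection_dicts` is not part of torchmetrics); it assumes the `(N, 4)` box layout that torchvision detection models return, including `N == 0` for an image with no detections.

```python
from typing import Dict, List

import torch

# Expected dtypes per key, taken from the lists above.
TARGET_DTYPES = {"boxes": torch.float32, "labels": torch.int64}
PRED_DTYPES = {"boxes": torch.float32, "scores": torch.float32, "labels": torch.int64}


def check_detection_dicts(
    dicts: List[Dict[str, torch.Tensor]], expected: Dict[str, torch.dtype]
) -> List[str]:
    """Return human-readable problems; an empty list means every dict matches."""
    problems = []
    for i, d in enumerate(dicts):
        for key, dtype in expected.items():
            if key not in d:
                problems.append(f"item {i}: missing '{key}'")
            elif d[key].dtype != dtype:
                problems.append(f"item {i}: '{key}' is {d[key].dtype}, expected {dtype}")
        # Boxes should be 2-D with 4 columns, even when there are zero boxes.
        if "boxes" in d and (d["boxes"].ndim != 2 or d["boxes"].shape[-1] != 4):
            problems.append(
                f"item {i}: 'boxes' has shape {tuple(d['boxes'].shape)}, expected (N, 4)"
            )
    return problems


# An image with no detections, written with explicit dtypes and shapes.
empty_pred = {
    "boxes": torch.zeros((0, 4), dtype=torch.float32),
    "scores": torch.zeros(0, dtype=torch.float32),
    "labels": torch.zeros(0, dtype=torch.int64),
}
print(check_detection_dicts([empty_pred], PRED_DTYPES))  # []
```

A dict built from bare `torch.tensor([])` values fails this check: its boxes have shape `(0,)` and its labels are float32.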

Error message

  File "/homes/vsoboleva/scripts/pascal_voc/train.py", line 65, in validation_epoch_end
    mAP = self.metric.compute()
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/metric.py", line 523, in wrapped_func
    value = compute(*args, **kwargs)
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/detection/mean_ap.py", line 908, in compute
    precisions, recalls = self._calculate(classes)
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/detection/mean_ap.py", line 758, in _calculate
    recall, precision, scores = MeanAveragePrecision.__calculate_recall_precision_scores(
  File "/homes/vsoboleva/miniconda3/lib/python3.9/site-packages/torchmetrics/detection/mean_ap.py", line 831, in __calculate_recall_precision_scores
    det_scores = torch.cat([e["dtScores"][:max_det] for e in img_eval_cls_bbox])
RuntimeError: expected scalar type Float but found Bool

Expected behavior

self.metric.compute() computes the values correctly and does not fail with RuntimeError: expected scalar type Float but found Bool.

Environment

  • TorchMetrics 0.9.2, installed with pip
  • Python 3.9.12, torch 1.12.0, torchvision 0.13.0
  • OS (e.g., Linux): Ubuntu 20.04.3

Additional context

V-Soboleva, Jul 15 '22

Hi! Thanks for your contribution! Great first issue!

github-actions[bot], Jul 15 '22

Hey @V-Soboleva, I've been dealing with the same issue. I think it comes from the 'scores' prediction of the Mask R-CNN model: it seems that when no detection is found, the model outputs a tensor of type bool. I'm unsure whether it's fixed yet, but I've just cast it as follows:

output['scores'] = output['scores'].type(torch.float)

dreaquil, Jul 18 '22

Thank you @dreaquil, but unfortunately it didn't help : (

V-Soboleva, Jul 18 '22

Hey @V-Soboleva! Yeah, just ran into the same issue again... Will let you know if I find a resolution.

As a note, I also tried sanitising the dtypes of the entire preds input, which didn't work either:

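from typing import Dict, List

import torch


# Posted as a @staticmethod inside a class: cast every field of each
# prediction dict to a fixed dtype before passing it to the metric.
def _sanitise_preds(preds: List[Dict]) -> List[Dict]:
    for pred in preds:
        pred["boxes"] = pred["boxes"].type(torch.float)
        pred["scores"] = pred["scores"].type(torch.float)
        pred["labels"] = pred["labels"].type(torch.int)
        if "masks" in pred:
            pred["masks"] = pred["masks"].type(torch.bool)
    return preds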

dreaquil, Jul 18 '22

I also noticed that if I run the following code without PyTorch Lightning, it works and gives reasonable values:

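import torch
from tqdm import tqdm
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# `model` and `dataloader` are defined as in the training script
metric = MeanAveragePrecision(box_format='xyxy', iou_type='bbox')

model.eval()  # torchvision detection models return predictions only in eval mode
with torch.no_grad():
    for batch_id, batch in tqdm(enumerate(dataloader)):
        images, targets = batch
        images = [image.to('cuda') for image in images]
        targets = [{k: v.to('cuda') for k, v in t.items()} for t in targets]
        preds = model(images)
        metric.update(preds, targets)
mAP = metric.compute()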

So, the problem is not in predictions or targets.

V-Soboleva, Jul 18 '22

I don't think you can rule out problems in the predictions (though perhaps you can in the targets). In the code above you're not doing any weight updates, so, unlike during training, your model's output predictions don't change from epoch to epoch. It just means that with the random initialisation, the output is unlikely to trigger the error.

It does likely rule out errors in the targets, though :)

dreaquil, Jul 18 '22

@V-Soboleva, when you got the error, did it happen at a random epoch or immediately?

dreaquil, Jul 18 '22

Okay, I think I found it:

https://github.com/Lightning-AI/metrics/blob/31c384411bc9a28f4ad2085cf123f68f382b6f82/torchmetrics/detection/mean_ap.py#L505

I believe the dtype on that line needs to be changed from torch.bool to torch.float32.

Can you check that by changing the above, you're no longer experiencing the issue please?
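To illustrate what the proposed change buys, here is a minimal sketch assuming per-image score tensors are gathered with `torch.cat`, as in the traceback above. The tensor names are hypothetical stand-ins, not the actual torchmetrics internals: once the empty placeholder for an image without detections is created as float32 instead of bool, every entry passed to `torch.cat` already shares one dtype.

```python
import torch

# Hypothetical stand-ins for the per-image score tensors that are concatenated
# with torch.cat. An image with no detections contributes an empty placeholder.
scores_img1 = torch.tensor([0.9, 0.4])                   # real detection scores (float32)
placeholder_float = torch.zeros(0, dtype=torch.float32)  # placeholder after the proposed change

# With a float32 placeholder every entry shares one dtype, so torch.cat never
# has to reconcile bool with float, whatever device the tensors live on.
det_scores = torch.cat([placeholder_float, scores_img1])
print(det_scores.dtype)  # torch.float32
```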

dreaquil, Jul 18 '22

Hi, don't mind me, just sliding into this thread. I had the same issue. @dreaquil, changing the type fixed the issue for me. Best, Simon

Simon128, Jul 18 '22

Thanks for confirming, @Simon128. It seems to have fixed it for me too. I'll create a PR.

dreaquil, Jul 18 '22

@dreaquil, changing the type helped me as well :) Thank you very much!

V-Soboleva, Jul 18 '22

For testing purposes, to make sure the bug is indeed fixed, could someone provide a single input example on which this currently fails?

SkafteNicki, Jul 19 '22

Hi @SkafteNicki, the failure is related to cases where the model outputs no detections. I will try to get a specific example later this morning.

dreaquil, Jul 19 '22

Hi @V-Soboleva & @Simon128, would you happen to have an easy way to retrieve a single input for the above failure? I'm finding it quite difficult to reproduce again 😅

dreaquil, Jul 19 '22

Hi @dreaquil, @SkafteNicki, I tried to reproduce the error by saving the predictions and targets on which it occurs during training. However, when I compute the metric on these saved tensors without training, the error does not occur 🤷‍♀️. I do have a Colab notebook where the error occurs at 50% of the first epoch (when validation starts), if that would be helpful.

V-Soboleva, Jul 19 '22

Hi @SkafteNicki, @V-Soboleva, I am having the same issue as @V-Soboleva. Taking the preds and targets that cause the failure during training does not reproduce it in a minimal example with just the metric. I have turned @V-Soboleva's Colab notebook into a script that is seeded and limits the number of training batches, so it fails quickly: pascal.txt

dreaquil, Jul 20 '22

This issue happens when using PyTorch 1.12.0 on a GPU device. Minimal code to reproduce it:

In [1]: import torch

In [2]: torch.__version__
Out[2]: '1.12.0+cu102'

In [3]: torch.cat([torch.zeros(0, dtype=torch.bool, device="cpu"), torch.zeros(1, dtype=torch.float32, device="cpu")])
Out[3]: tensor([0.])

In [4]: torch.cat([torch.zeros(0, dtype=torch.bool, device="cuda:0"), torch.zeros(1, dtype=torch.float32, device="cuda:0")])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 torch.cat([torch.zeros(0, dtype=torch.bool, device="cuda:0"), torch.zeros(1, dtype=torch.float32, device="cuda:0")])

RuntimeError: expected scalar type Float but found Bool

The CPU implementation (and other PyTorch versions) can promote an empty bool tensor to float when concatenating, but dtScores should be initialized as a float tensor anyway, because scores are real numbers.
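As a more general defensive pattern (not necessarily what the eventual PR does), mixed-dtype concatenation can be avoided by promoting all inputs to a common dtype before calling `torch.cat`. The helper name `cat_promoted` is hypothetical; `torch.promote_types` is the standard PyTorch type-promotion function.

```python
from typing import Sequence

import torch


def cat_promoted(tensors: Sequence[torch.Tensor]) -> torch.Tensor:
    """Concatenate tensors after casting them all to their common promoted dtype.

    Doing the promotion explicitly means torch.cat only ever sees a single
    dtype, so it does not rely on cat's own mixed-dtype handling.
    """
    common = tensors[0].dtype
    for t in tensors[1:]:
        common = torch.promote_types(common, t.dtype)
    return torch.cat([t.to(common) for t in tensors])


# Same inputs as the reproduction above; uses the GPU when one is available.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
out = cat_promoted([
    torch.zeros(0, dtype=torch.bool, device=device),
    torch.zeros(1, dtype=torch.float32, device=device),
])
print(out.dtype, out.shape)  # torch.float32 torch.Size([1])
```

On an affected PyTorch 1.12.0 GPU setup this should exercise the same inputs as `In [4]` while handing `torch.cat` a single dtype.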

@SkafteNicki

For testing purposes, to make sure the bug is indeed fixed, could someone provide a single input example on which this currently fails?

The current test cases in tests/unittests/detection/test_map.py do not cover this bool/float concatenation case: there is only one pred/target pair in _inputs3. Adding another pair can make the tests fail, e.g.:

_inputs3 = Input(
    preds=[
        [
            dict(
                boxes=Tensor([[258.0, 41.0, 606.0, 285.0]]),
                scores=Tensor([0.536]),
                labels=IntTensor([0]),
            ),
        ],
        [
            dict(boxes=Tensor([]), scores=Tensor([]), labels=Tensor([])),
        ],
    ],
    target=[
        [
            dict(
                boxes=Tensor([[214.0, 41.0, 562.0, 285.0]]),
                labels=IntTensor([0]),
            )
        ],
        [
            dict(
                boxes=Tensor([[1.0, 2.0, 3.0, 4.0]]),
                scores=Tensor([0.8]),  # extra key: targets do not normally carry scores
                labels=Tensor([1]),
            ),
        ],
    ],
)

aaronzs, Jul 29 '22

What's the recommended workaround for this currently?

austinmw, Oct 10 '22

@austinmw, this should now be fixed by PR #1150. Please try installing from master:

pip install https://github.com/Lightning-AI/metrics/archive/master.zip

which should solve the issue. If not, please report back and we can reopen the issue and try to fix it.

SkafteNicki, Oct 11 '22