
MeanAveragePrecision is slow

Open senarvi opened this issue 2 years ago • 19 comments

🐛 Bug

Computing the mean-average-precision is extremely slow in torchmetrics versions newer than 0.6.0.

To Reproduce

I noticed that my training times have almost doubled since I upgraded torchmetrics from 0.6.0, because validation using the MAP / MeanAveragePrecision metric is so much slower. During validation steps I call update(), and at the end of a validation epoch I call compute() on the MeanAveragePrecision object.

I measured the time spent inside compute() with different torchmetrics versions:

  • torchmetrics 0.6.0: 12 s
  • torchmetrics 0.6.1: didn't work for some reason
  • torchmetrics 0.6.2: 9.5 min
  • torchmetrics 0.7.0: 9.4 min
  • torchmetrics 0.7.1: 1.9 min
  • torchmetrics 0.7.2: 2.0 min
  • torchmetrics 0.7.3: 1.9 min
  • torchmetrics 0.8.0: 4.5 min
  • torchmetrics 0.8.1: 4.6 min
  • torchmetrics 0.8.2: 4.6 min

It seems that after 0.6.0 the time to run compute() increased from about 10 seconds to 9.5 minutes. In 0.7.1 it was improved and took about 2 minutes. Then in 0.8.0 things got worse again and compute() took 4.5 minutes. That is more than 20x slower than 0.6.0 and, for example, adds another 7 hours when training for 100 epochs.
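(Roughly, assuming one validation pass per epoch: (4.6 min − 12 s) × 100 epochs ≈ 440 minutes ≈ 7.3 hours of extra time spent in compute().)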

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.6.0 through 0.8.2, installed using pip
  • Python & PyTorch Version (e.g., 1.0): Python 3.8.11, PyTorch 1.10.0
  • Any other relevant information such as OS (e.g., Linux): Linux

senarvi avatar May 10 '22 13:05 senarvi

Hi! Thanks for your contribution, great first issue!

github-actions[bot] avatar May 10 '22 13:05 github-actions[bot]

Interesting, and thx for such a rigorous comparison per version... Could you pls also share your benchmarking code?

Borda avatar May 10 '22 14:05 Borda

  • The change from 0.6.0 to 0.6.2+ was due to us switching from pycocotools to a pure PyTorch implementation https://github.com/PyTorchLightning/metrics/pull/632
  • In 0.7.1 we did improve performance https://github.com/PyTorchLightning/metrics/pull/742
  • I assume that in 0.8.0 it was this PR that made it slower https://github.com/PyTorchLightning/metrics/pull/950

@Borda we can either try to improve our own version to get the computation time down (if that is possible) or offer pycocotools as an optional backend for users
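For reference, a minimal sketch of what the pycocotools path looks like on its own (the annotation/result file names below are placeholders, and this is not wired into torchmetrics):

from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

# Placeholder file names: COCO-format ground truth and detection results.
coco_gt = COCO("instances_val.json")
coco_dt = coco_gt.loadRes("detections.json")

coco_eval = COCOeval(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()    # per-image, per-category matching
coco_eval.accumulate()  # build precision/recall curves
coco_eval.summarize()   # prints the 12 standard COCO metrics
map_50_95 = coco_eval.stats[0]  # mAP @ IoU=0.50:0.95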

SkafteNicki avatar May 10 '22 14:05 SkafteNicki

@senarvi could you kindly provide the benchmarking script that you used, so that we have something to work with when trying to improve the runtime.

SkafteNicki avatar May 11 '22 05:05 SkafteNicki

I created a subclass of the YOLO model, but I guess you could subclass any model in the same way. I used proprietary data, but you could use any data. However, if we want to compare each other's results, we should decide what model and data to use. But maybe it would be easiest to just create a bunch of random detections and targets? Anyway, this is more or less the code that I used:

import time
from pl_bolts.models.detection import YOLO

try:
    from torchmetrics.detection import MeanAveragePrecision
except ImportError:
    from torchmetrics.detection import MAP
    MeanAveragePrecision = MAP

class LogTime:
    def __init__(self, name, model):
        self._name = name
        self._model = model

    def __enter__(self):
        self._start_time = time.perf_counter()

    def __exit__(self, exc_type, exc_val, exc_tb):
        end_time = time.perf_counter()
        self._model.log(self._name, end_time - self._start_time)
        return True

class TimingYOLO(YOLO):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._val_map = MeanAveragePrecision()
        self._test_map = MeanAveragePrecision()

    def validation_step(self, batch, batch_idx):
        with LogTime("val/time_detection", self):
            images, targets = self._validate_batch(batch)
            detections, losses = self(images, targets)

        with LogTime("val/time_processing", self):
            detections = self.process_detections(detections)
            targets = self.process_targets(targets)
            self._val_map.update(detections, targets)

    def validation_epoch_end(self, outputs):
        with LogTime("val/time_scoring", self):
            map_scores = self._val_map.compute()
            map_scores = {"val/" + k: v for k, v in map_scores.items()}
            self.log_dict(map_scores, sync_dist=True)
            self._val_map.reset()

    def test_step(self, batch, batch_idx):
        with LogTime("test/time_detection", self):
            images, targets = self._validate_batch(batch)
            detections, losses = self(images, targets)

        with LogTime("test/time_processing", self):
            detections = self.process_detections(detections)
            targets = self.process_targets(targets)
            self._test_map.update(detections, targets)

    def test_epoch_end(self, outputs):
        with LogTime("test/time_scoring", self):
            map_scores = self._test_map.compute()
            map_scores = {"test/" + k: v for k, v in map_scores.items()}
            self.log_dict(map_scores, sync_dist=True)
            self._test_map.reset()

senarvi avatar May 11 '22 06:05 senarvi

@senarvi, yes the model does not really matter, it is the metric computations that are important. Essentially, this is the code we want to measure:

with LogTime("init"):
   metric = MeanAveragePrecision()
with LogTime("update"):
   for batch in dataloader:
      metric.update(batch)
with LogTime("compute"):
   _ = metric.compute()

This should remove any variation from other Lightning code. If you could provide some code to generate random data for testing, that should be enough.

SkafteNicki avatar May 11 '22 07:05 SkafteNicki

I can try to do that.

senarvi avatar May 11 '22 08:05 senarvi

Here's some kind of a benchmark script:

import time
import torch

try:
    from torchmetrics.detection import MeanAveragePrecision
except ImportError:
    from torchmetrics.detection import MAP
    MeanAveragePrecision = MAP

total_time = dict()

class UpdateTime:
    def __init__(self, name):
        self._name = name

    def __enter__(self):
        self._start_time = time.perf_counter()

    def __exit__(self, exc_type, exc_val, exc_tb):
        end_time = time.perf_counter()
        if self._name in total_time:
            total_time[self._name] += end_time - self._start_time
        else:
            total_time[self._name] = end_time - self._start_time
        return True

def generate(n):
    # Random boxes in (x1, y1, x2, y2) format; adding the top-left corner to the
    # second half guarantees x2 >= x1 and y2 >= y1.
    boxes = torch.rand(n, 4) * 1000
    boxes[:, 2:] += boxes[:, :2]
    labels = torch.randint(0, 10, (n,))
    scores = torch.rand(n)
    return {"boxes": boxes, "labels": labels, "scores": scores}

with UpdateTime("init"):
    map = MeanAveragePrecision()

for batch_idx in range(100):
    with UpdateTime("update"):
        detections = [generate(100) for _ in range(10)]
        targets = [generate(10) for _ in range(10)]
        map.update(detections, targets)

with UpdateTime("compute"):
    map.compute()

for name, elapsed in total_time.items():
    print(f"Total time in {name}: {elapsed}")

My results:

$ pip install torchmetrics==0.6.0
$ ./map_benchmark.py
Total time in init: 1.5747292000014568
Total time in update: 0.1246876999939559
Total time in compute: 6.245588799996767
$ pip install torchmetrics==0.8.2
$ ./map_benchmark.py
Total time in init: 0.0003580999909900129
Total time in update: 0.08986139997432474
Total time in compute: 151.69804470000963

senarvi avatar May 11 '22 13:05 senarvi

@senarvi cool, this already clearly shows that any improvement we can make is in compute :)

SkafteNicki avatar May 12 '22 07:05 SkafteNicki

I just ran a profile of the script. It is clear that the majority of time is spent in the _find_best_gt_match function. It is important to note that each function call is actually quite fast; however, the function gets called a ridiculous number of times.
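Not a proposed patch, just an illustration of one possible direction: compute the full IoU matrix once per image with torchvision.ops.box_iou and do the greedy matching over its rows, instead of calling a helper once per detection. The greedy_match function and its iou_threshold argument below are invented for the example and are not the torchmetrics code.

import torch
from torchvision.ops import box_iou

def greedy_match(det_boxes, gt_boxes, iou_threshold=0.5):
    """Greedily match score-sorted detections to ground-truth boxes (xyxy).

    Returns one gt index per detection (-1 means unmatched). Illustrative
    sketch only, not the torchmetrics implementation.
    """
    matches = torch.full((det_boxes.shape[0],), -1, dtype=torch.long)
    if gt_boxes.numel() == 0:
        return matches
    ious = box_iou(det_boxes, gt_boxes)      # (num_det, num_gt), computed once
    gt_taken = torch.zeros(gt_boxes.shape[0], dtype=torch.bool)
    for d in range(det_boxes.shape[0]):      # loop over detections, not a helper call per pair
        row = ious[d].clone()
        row[gt_taken] = -1.0                 # already-matched gt boxes are unavailable
        best_iou, best_gt = row.max(dim=0)
        if best_iou >= iou_threshold:
            matches[d] = best_gt
            gt_taken[best_gt] = True
    return matches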

SkafteNicki avatar May 12 '22 12:05 SkafteNicki

It is clear that the majority of time is spent in the _find_best_gt_match function. It is important to note that each function call is actually quite fast; however, the function gets called a ridiculous number of times.

Very nice finding! Do any of you want to take it and find some boost? cc: @twsl @PyTorchLightning/core-metrics

Borda avatar May 12 '22 12:05 Borda

Perhaps we can implement this for fast evaluation. It does require C++, however.

https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/evaluation/fast_eval_api.py
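If I remember that file correctly, COCOeval_opt there is meant as a drop-in replacement for pycocotools' COCOeval (evaluate() and accumulate() are reimplemented in C++), so usage would look roughly like this untested sketch with placeholder file names:

from pycocotools.coco import COCO
from detectron2.evaluation.fast_eval_api import COCOeval_opt

coco_gt = COCO("instances_val.json")           # placeholder ground-truth file
coco_dt = coco_gt.loadRes("detections.json")   # placeholder detections file

# Same interface as pycocotools' COCOeval, only with faster evaluate()/accumulate().
coco_eval = COCOeval_opt(coco_gt, coco_dt, iouType="bbox")
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()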

24hours avatar Jun 10 '22 06:06 24hours

@24hours I think the way to go here would be to first try and clean up the code before we decide to dispatch to C++

justusschock avatar Jun 10 '22 07:06 justusschock

Hi. Thanks for the great package.

I found the following simple corrections and redundancies in mean_ap.py.

https://github.com/Lightning-AI/metrics/blob/fbe06ef4d9c5b5a21ead813e40d371d6fdbd3cd7/src/torchmetrics/detection/mean_ap.py#L479-L481

The above code should be changed as follows. It may not have much effect on speed, but it should be better than scanning every box.

if len(inds) > max_det:
    inds = inds[:max_det]
det = [det[i] for i in inds] 

The same pattern appears in other parts of the code, so corresponding changes are needed there as well:

https://github.com/Lightning-AI/metrics/blob/fbe06ef4d9c5b5a21ead813e40d371d6fdbd3cd7/src/torchmetrics/detection/mean_ap.py#L526-L528

https://github.com/Lightning-AI/metrics/blob/fbe06ef4d9c5b5a21ead813e40d371d6fdbd3cd7/src/torchmetrics/detection/mean_ap.py#L603-L605

I also think there is some redundancy here.

https://github.com/Lightning-AI/metrics/blob/fbe06ef4d9c5b5a21ead813e40d371d6fdbd3cd7/src/torchmetrics/detection/mean_ap.py#L458-L471

Lines 464 and 470 appear to be duplicates.

mjun0812 avatar Sep 17 '22 14:09 mjun0812

From my quick survey, this seems to be the biggest bottleneck:

https://github.com/Lightning-AI/metrics/blob/fbe06ef4d9c5b5a21ead813e40d371d6fdbd3cd7/src/torchmetrics/detection/mean_ap.py#L734-L739

mjun0812 avatar Sep 18 '22 16:09 mjun0812

Hi @senarvi, thanks for reporting this issue. Could you please try the implementation from #1259 and verify if you can observe any improvement, and if all results are correct? :] I believe there is more space to optimize the metric, but let's go step by step :]

stancld avatar Oct 10 '22 07:10 stancld

Hi @stancld, sorry I didn't have time to respond earlier. I wrote my observations in that pull request. It was indeed a lot faster, but in one case the results were significantly different. I don't normally see such a large variation between test runs.

senarvi avatar Oct 17 '22 17:10 senarvi

Hi @senarvi, thanks for the feedback. Do you have any example batch you can share where the results are different, so that we can test and debug it, please?

stancld avatar Oct 17 '22 17:10 stancld

@stancld Hmm. The data's not public. I wonder if you could debug it using random boxes, like in the speed test. I modified it to make the task a little bit easier and to make sure that the results are deterministic:

import torch
from torchmetrics.detection import MeanAveragePrecision

torch.manual_seed(1)

def generate(n):
    boxes = torch.rand(n, 4) * 10
    boxes[:, 2:] += boxes[:, :2] + 10
    labels = torch.randint(0, 2, (n,))
    scores = torch.rand(n)
    return {"boxes": boxes, "labels": labels, "scores": scores}

batches = []
for _ in range(100):
    detections = [generate(100) for _ in range(10)]
    targets = [generate(10) for _ in range(10)]
    batches.append((detections, targets))

map = MeanAveragePrecision()
for detections, targets in batches:
    map.update(detections, targets)
print(map.compute())

With torchmetrics 0.10.0 I get:

map: 0.1534 map_50: 0.5260 map_75: 0.0336 map_small: 0.1534 mar_1: 0.0449 mar_10: 0.3039 mar_100: 0.5445 mar_small: 0.5445

With the code from your PR I get

map: 0.2222 map_50: 0.7135 map_75: 0.0594 map_small: 0.2222 mar_1: 0.0449 mar_10: 0.4453 mar_100: 2.2028 mar_small: 2.2028

Some recall values are also > 1, which should not be possible.

senarvi avatar Oct 17 '22 20:10 senarvi