
SemanticSegmentationTask: add class-wise metrics

Open · robmarkcole opened this issue 1 year ago • 34 comments

Addresses https://github.com/microsoft/torchgeo/issues/2121 for segmentation. Mostly copied from @isaaccorley as here - note that he additionally passes on_epoch=True, which is NOT adopted here.

Output metrics for the ChaBud binary task with labels=['background', 'burned_area']. This dataset nicely illustrates why class labels are required - burned_area is the minority class and is not learnt:

[{'test_loss': 450.0233459472656,
  'test_multiclassaccuracy_background': 0.9817732572555542,
  'test_multiclassaccuracy_burned_area': 0.006427088752388954,
  'test_AverageAccuracy': 0.4941001236438751,
  'test_AverageF1Score': 0.4793838560581207,
  'test_AverageJaccardIndex': 0.4542679488658905,
  'test_multiclassfbetascore_background': 0.9489027857780457,
  'test_multiclassfbetascore_burned_area': 0.009864915162324905,
  'test_multiclassjaccardindex_background': 0.9035778641700745,
  'test_multiclassjaccardindex_burned_area': 0.004958001431077719,
  'test_OverallAccuracy': 0.9036323428153992,
  'test_OverallF1Score': 0.9036323428153992,
  'test_OverallJaccardIndex': 0.8265058994293213,
  'test_multiclassprecision_background': 0.9189077615737915,
  'test_multiclassprecision_burned_area': 0.0312582366168499,
  'test_multiclassrecall_background': 0.9817732572555542,
  'test_multiclassrecall_burned_area': 0.006427088752388954}]
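
For reference, a minimal sketch of how such class-wise metrics can be wired up with torchmetrics' ClasswiseWrapper (illustrative only, not the exact code in this PR; the labels list mirrors the one above, and the wrapper produces keys in the same style as the output, e.g. 'multiclassaccuracy_burned_area'):

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchmetrics.wrappers import ClasswiseWrapper

labels = ["background", "burned_area"]
num_classes = len(labels)

metrics = MetricCollection(
    {
        # Aggregate metrics (micro = overall, macro = unweighted class average)
        "OverallAccuracy": MulticlassAccuracy(num_classes=num_classes, average="micro"),
        "AverageAccuracy": MulticlassAccuracy(num_classes=num_classes, average="macro"),
        # ClasswiseWrapper expands the per-class tensor into individually named entries
        "ClasswiseAccuracy": ClasswiseWrapper(
            MulticlassAccuracy(num_classes=num_classes, average=None), labels=labels
        ),
        "ClasswiseJaccardIndex": ClasswiseWrapper(
            MulticlassJaccardIndex(num_classes=num_classes, average=None), labels=labels
        ),
    },
    prefix="test_",
)

preds = torch.randint(0, num_classes, (4, 64, 64))
target = torch.randint(0, num_classes, (4, 64, 64))
print(metrics(preds, target))  # keys like 'test_multiclassaccuracy_burned_area'
```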

robmarkcole avatar Jun 19 '24 10:06 robmarkcole

Given that most metrics of interest are broken (e.g., all of them when average="macro" and ignore_index is specified (https://github.com/Lightning-AI/torchmetrics/pull/2443), and JaccardIndex, which outputs NaN when average="macro" even when you try to take absent and ignored classes into account with zero_division (https://github.com/Lightning-AI/torchmetrics/issues/2535)), should we make an effort to see if and how we could add our own?

I'm saying this because these are only the issues I've found so far, but I've also noticed other suspicious things, like the fact that my class-wise recall values are not the same as those in the confusion matrix when it is normalized with respect to the ground truth (I haven't checked whether this is also the case with precision, i.e. when the matrix is normalized column-wise). I'm also pretty confident that if all of this is wrong, then micro averaging is probably wrong too.

It should be pretty easy to compute all these metrics straight from the confusion matrix (assuming it, at least, is correct), and I've actually tried to reimplement them this way, but it hasn't really been a priority because I've found that all these wrong (?) values are basically a lower bound of the actual ones. If you look at the official implementations, this is actually what they are doing, so my guess is that they have a bug in their logic later on. Indeed, all these metrics inherit from StatScores, which is basically the confusion matrix.
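
For what it's worth, a minimal sketch of deriving these metrics straight from a confusion matrix (purely illustrative; assumes a square array cm with ground truth on the rows):

```python
import numpy as np


def metrics_from_confusion_matrix(cm: np.ndarray) -> dict[str, np.ndarray | float]:
    """Per-class metrics from a confusion matrix with ground truth on the rows."""
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp  # predicted as class c, actually another class
    fn = cm.sum(axis=1) - tp  # actually class c, predicted as another class
    eps = 1e-12  # guard against division by zero for absent classes
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)  # identical to per-class accuracy
    iou = tp / (tp + fp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return {
        "precision": precision,
        "recall": recall,
        "iou": iou,
        "f1": f1,
        "overall_accuracy": tp.sum() / cm.sum(),
    }
```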

I'm actually pretty dumbfounded that these issues are not a top priority for the TorchMetrics team and that they focus on adding to their docs instead, but to each their own…

DimitrisMantas avatar Jun 20 '24 08:06 DimitrisMantas

@DimitrisMantas good call on my ignoring the ignore_index! In fairness, they do address issues, but they have a long backlog - when I made some noise they addressed https://github.com/Lightning-AI/torchmetrics/pull/2198. My opinion is that it is better to work with torchmetrics to address the issues, rather than implement from scratch here. I see your comment at https://github.com/Lightning-AI/torchmetrics/issues/2535#issuecomment-2143514389, so perhaps a pragmatic approach is not to add new metrics that we have concerns about, but also to create specific issues which track these concerns.

robmarkcole avatar Jun 20 '24 08:06 robmarkcole

Sure, that makes sense; please excuse the rant haha.

DimitrisMantas avatar Jun 20 '24 09:06 DimitrisMantas

Applied on_epoch=True to all steps for consistency - this results in both per-epoch and per-step values being reported, but for train only - perhaps this is why @isaaccorley did not apply it to train?

train_loss_epoch | 0.028535427525639534
train_loss_step | 0.00008003244874998927
train_AverageAccuracy_epoch | 0.9101453423500061
train_AverageAccuracy_step | 0.9124529361724854

Note that Val is unaffected:

val_AverageAccuracy | 0.8227439522743225

For a task with 2 classes there is a grand total of 52 metrics being reported between train & val.

robmarkcole avatar Jun 20 '24 09:06 robmarkcole

I just set it to be explicit, but I think that PyTorch Lightning or torchmetrics automatically sets on_epoch to False for training and True for everything else.

isaaccorley avatar Jun 20 '24 12:06 isaaccorley

You need to set both on_step and on_epoch to get logs only per step or per epoch.
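
For example (a sketch of Lightning's self.log flags with hypothetical names; to my understanding the training defaults are on_step=True, on_epoch=False, and the reverse for val/test):

```python
import torch
from lightning.pytorch import LightningModule


class LoggingDemo(LightningModule):
    """Hypothetical module showing how on_step/on_epoch control what gets logged."""

    def training_step(self, batch, batch_idx):
        loss = torch.tensor(0.0, requires_grad=True)  # placeholder loss
        self.log("train_loss", loss, on_step=True, on_epoch=False)  # per-step only
        # Setting both -> logged twice, with '_step' and '_epoch' suffixes appended
        self.log("train_AverageAccuracy", 0.9, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Validation defaults are already on_step=False, on_epoch=True
        self.log("val_AverageAccuracy", 0.8)
```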

DimitrisMantas avatar Jun 20 '24 12:06 DimitrisMantas

@DimitrisMantas now only logging on_step for the train loss, so a more manageable 36 metrics.

robmarkcole avatar Jun 20 '24 13:06 robmarkcole

Not sure about this failing test: ValueError: Problem with given class_path 'torchgeo.trainers.SemanticSegmentationTask'

robmarkcole avatar Jun 21 '24 08:06 robmarkcole

Must be an issue with one of the minimum versions of the packages, since it's passing for the other tests.

isaaccorley avatar Jun 21 '24 14:06 isaaccorley

We can definitely increase the min version of torchmetrics if we need to.

adamjstewart avatar Aug 06 '24 11:08 adamjstewart

If we do this, we should do it for every trainer, not just segmentation. Want to be consistent here.

adamjstewart avatar Aug 06 '24 11:08 adamjstewart

@adamjstewart I ran the tests on my branch and they pass - are they somehow run differently in CI?

tests/trainers/test_segmentation.py ...........................................                                                                                                                             [100%]

================================================================================================ warnings summary =================================================================================================
tests/trainers/test_segmentation.py::TestSemanticSegmentationTask::test_trainer[True-spacenet1]
tests/trainers/test_segmentation.py::TestSemanticSegmentationTask::test_trainer[True-spacenet1]
  /home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 2 is 0. This might result in an empty dataset.
    warnings.warn(f"Length of split at index {i} is 0. "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================= 43 passed, 31 deselected, 2 warnings in 55.59s ==================================================================================

robmarkcole avatar Aug 06 '24 13:08 robmarkcole

The failing tests are for the minimum supported version of our dependencies. If you install the same version of torchmetrics we're using in https://github.com/microsoft/torchgeo/blob/main/requirements/min-reqs.old, you should be able to reproduce the bug. Feel free to increase that version if needed.

adamjstewart avatar Aug 06 '24 13:08 adamjstewart

@adamjstewart tests passing

robmarkcole avatar Aug 06 '24 14:08 robmarkcole

The new min is pretty recent. Did you ever figure out why older versions were raising an error?

adamjstewart avatar Aug 06 '24 15:08 adamjstewart

This PR does a lot of stuff:

  • Add macro, micro, and classwise averaging
  • Switch from MulticlassAccuracy to Accuracy, etc.
  • Modify on_epoch/on_step logging frequency

Can we split these into separate PRs so it's easier to review and clearer which changes necessitate what? The first couple are likely easy to approve, the latter may require justification. I'm also worried that we'll change these only for segmentation and forget about the rest and end up with non-uniform metrics.
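
On the second bullet above: the two classes are closely related, since Accuracy with task="multiclass" is a factory that returns a MulticlassAccuracy instance, so that part of the change should be behaviour-preserving (a quick sketch, assuming torchmetrics >= 0.11):

```python
from torchmetrics import Accuracy
from torchmetrics.classification import MulticlassAccuracy

# Accuracy(task="multiclass", ...) dispatches to and returns a MulticlassAccuracy
metric = Accuracy(task="multiclass", num_classes=2, average="macro")
assert isinstance(metric, MulticlassAccuracy)
```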

adamjstewart avatar Aug 06 '24 15:08 adamjstewart

@adamjstewart I also tried torchmetrics==1.1.0 and got the failures, but didn't dig into it.

robmarkcole avatar Aug 06 '24 16:08 robmarkcole

@adamjstewart I have dropped the on_epoch/on_step modifications from this PR. It feels manageable to me now, but I can split it further if required.

robmarkcole avatar Aug 07 '24 13:08 robmarkcole

@adamjstewart I'm inclined to close this PR as I don't feel confident I understand the behaviour of torchmetrics in this implementation. Elsewhere I am using the on_stage_epoch_end hooks and feel confident I do understand the behaviour with that approach. Overall I think this should be a change we make from a place of understanding, and in smaller steps than this PR takes

robmarkcole avatar Aug 07 '24 14:08 robmarkcole

torchmetrics==1.1.0 test errors here:

MeanAveragePrecision(), kwargs = {'average': 'macro'}
...
ValueError: Unexpected keyword arguments: `average`

See this was added in 1.1.1
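
A minimal reproduction sketch of that failure (assuming torchmetrics==1.1.0 is installed; per the above, the argument is accepted on newer versions):

```python
from torchmetrics.detection import MeanAveragePrecision

# On torchmetrics 1.1.0 this raises:
#   ValueError: Unexpected keyword arguments: `average`
metric = MeanAveragePrecision(average="macro")
```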

robmarkcole avatar Aug 07 '24 14:08 robmarkcole

After discussion with torchmetrics devs, created https://github.com/Lightning-AI/torchmetrics/issues/2683

robmarkcole avatar Aug 08 '24 12:08 robmarkcole

That's such a complicated minimal reproducible example lol.

adamjstewart avatar Aug 08 '24 16:08 adamjstewart

I tried making a self-contained minimal reproducible example but couldn't get one working and gave up.

adamjstewart avatar Aug 21 '24 09:08 adamjstewart

It just hit me that we should be a bit careful with which metrics we add to avoid unnecessary computation; class-wise accuracy and recall are the same thing and so are micro-averaged accuracy, precision, and recall.
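
A quick sanity check of those equivalences (a sketch using torchmetrics; purely illustrative):

```python
import torch
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
)

preds = torch.tensor([0, 1, 1, 2, 2, 0])
target = torch.tensor([0, 1, 2, 2, 0, 0])

# Class-wise accuracy and class-wise recall are the same quantity
acc = MulticlassAccuracy(num_classes=3, average=None)(preds, target)
rec = MulticlassRecall(num_classes=3, average=None)(preds, target)
assert torch.allclose(acc, rec)

# Micro-averaged accuracy, precision, and recall all reduce to overall accuracy
micro = [
    m(num_classes=3, average="micro")(preds, target)
    for m in (MulticlassAccuracy, MulticlassPrecision, MulticlassRecall)
]
assert torch.allclose(torch.stack(micro), micro[0])
```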

DimitrisMantas avatar Aug 27 '24 14:08 DimitrisMantas

Any sense of how much these metrics actually add to processing time? If it isn't noticeable by a human, I don't particularly care about the overhead.

adamjstewart avatar Aug 27 '24 14:08 adamjstewart

Haven't measured it but doubt it's much.

DimitrisMantas avatar Aug 27 '24 14:08 DimitrisMantas

I believe Lightning offers tools for profiling

robmarkcole avatar Aug 27 '24 16:08 robmarkcole

They do, see https://torchgeo.readthedocs.io/en/latest/user/contributing.html#i-o-benchmarking
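
For example, something along these lines should be enough to see whether the extra metrics are noticeable (a sketch; profiler="simple" is a built-in Lightning option, and task/datamodule are hypothetical):

```python
from lightning.pytorch import Trainer

# Prints a per-hook timing summary (including metric updates) when the run finishes
trainer = Trainer(profiler="simple", max_epochs=1)
# trainer.fit(task, datamodule=datamodule)
```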

adamjstewart avatar Aug 28 '24 11:08 adamjstewart

@adamjstewart @DimitrisMantas per this comment, we should be using the _epoch_end hooks: https://github.com/Lightning-AI/torchmetrics/issues/2683#issuecomment-2331146337
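
A sketch of that pattern (hypothetical names, not the exact code from the linked issue): accumulate with update() in the step and log the collection once from the epoch-end hook.

```python
from lightning.pytorch import LightningModule
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex


class EpochEndLoggingTask(LightningModule):
    """Hypothetical task logging a MetricCollection from on_validation_epoch_end."""

    def __init__(self, num_classes: int = 2) -> None:
        super().__init__()
        self.val_metrics = MetricCollection(
            {
                "AverageAccuracy": MulticlassAccuracy(num_classes=num_classes, average="macro"),
                "AverageJaccardIndex": MulticlassJaccardIndex(num_classes=num_classes, average="macro"),
            },
            prefix="val_",
        )

    def validation_step(self, batch, batch_idx):
        preds, target = batch  # placeholder; a real task would run the model here
        self.val_metrics.update(preds, target)  # accumulate state only, no logging here

    def on_validation_epoch_end(self):
        self.log_dict(self.val_metrics.compute())  # log aggregated values once per epoch
        self.val_metrics.reset()
```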

robmarkcole avatar Sep 05 '24 10:09 robmarkcole

I see the issue, but I must be missing something, because my own code uses the standard logging tools and my metric collections work just fine.

Although by "work", I mean I don't get an error. Other than that, I found out a couple of days ago that the diagonal of my confusion matrix doesn't match the class accuracies (which it should), so I'm obviously not using the API correctly...

Edit: I have at least one mistake where I do self.log_dict(metrics(input, target)). The docs say this is wrong.

Edit 2: Aaaaand I finally got your error...

DimitrisMantas avatar Sep 05 '24 10:09 DimitrisMantas