SemanticSegmentationTask: add class-wise metrics
Addresses https://github.com/microsoft/torchgeo/issues/2121 for segmentation. Mostly copied from @isaaccorley's implementation; he additionally passes on_epoch=True, which is NOT adopted here.
Output metrics for the ChaBud binary task with labels=['background', 'burned_area'].
This dataset nicely illustrates why class labels are required: burned_area is the minority class and is not learnt.
[{'test_loss': 450.0233459472656,
'test_multiclassaccuracy_background': 0.9817732572555542,
'test_multiclassaccuracy_burned_area': 0.006427088752388954,
'test_AverageAccuracy': 0.4941001236438751,
'test_AverageF1Score': 0.4793838560581207,
'test_AverageJaccardIndex': 0.4542679488658905,
'test_multiclassfbetascore_background': 0.9489027857780457,
'test_multiclassfbetascore_burned_area': 0.009864915162324905,
'test_multiclassjaccardindex_background': 0.9035778641700745,
'test_multiclassjaccardindex_burned_area': 0.004958001431077719,
'test_OverallAccuracy': 0.9036323428153992,
'test_OverallF1Score': 0.9036323428153992,
'test_OverallJaccardIndex': 0.8265058994293213,
'test_multiclassprecision_background': 0.9189077615737915,
'test_multiclassprecision_burned_area': 0.0312582366168499,
'test_multiclassrecall_background': 0.9817732572555542,
'test_multiclassrecall_burned_area': 0.006427088752388954}]
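For reference, a minimal sketch of how class-wise metrics like the ones above can be assembled with torchmetrics, using MetricCollection plus ClasswiseWrapper with a labels argument; the metric keys and labels here mirror the output dump, but this is an illustration rather than the exact PR code:

from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchmetrics.wrappers import ClasswiseWrapper

num_classes = 2
labels = ['background', 'burned_area']

metrics = MetricCollection(
    {
        # micro averaging counts every pixel once, so it is dominated by the majority class
        'OverallAccuracy': MulticlassAccuracy(num_classes=num_classes, average='micro'),
        # macro averaging is an unweighted mean over classes
        'AverageAccuracy': MulticlassAccuracy(num_classes=num_classes, average='macro'),
        # class-wise values, one entry per label (e.g. multiclassaccuracy_burned_area)
        'ClasswiseAccuracy': ClasswiseWrapper(
            MulticlassAccuracy(num_classes=num_classes, average=None), labels=labels
        ),
        'ClasswiseJaccardIndex': ClasswiseWrapper(
            MulticlassJaccardIndex(num_classes=num_classes, average=None), labels=labels
        ),
    }
)
test_metrics = metrics.clone(prefix='test_')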
Given that most metrics of interest are broken (e.g., all of them when average="macro" and ignore_index is specified (https://github.com/Lightning-AI/torchmetrics/pull/2443), and JaccardIndex, which outputs NaN when average="macro" even when you try to take absent and ignored classes into account with zero_division (https://github.com/Lightning-AI/torchmetrics/issues/2535)), should we make an effort to see if and how we could add our own?
I'm saying this because these are only the issues I've found so far, but I've also noticed other suspicious things, like the fact that my class-wise recall values are not the same as those in the confusion matrix when it is normalized with respect to the ground truth (I haven't checked whether this also happens with precision, i.e., when the matrix is normalized column-wise). I'm also fairly confident that if all of this is wrong, then micro averaging is probably wrong too.
It should be pretty easy to compute all these metrics straight from the confusion matrix (assuming it, at least, is correct), and I've actually tried to reimplement them this way, but it hasn't really been a priority because I've found that all these wrong (?) values are basically a lower bound on the actual ones. If you look at the official implementations, this is actually what they are doing, so my guess is that they have a bug in their logic later on. Indeed, all these metrics inherit from StatScores, which is basically the confusion matrix.
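For what it's worth, a rough sketch of that confusion-matrix approach (assuming a (C, C) tensor with rows = ground truth and columns = predictions, the torchmetrics convention); absent classes would still need explicit zero-division handling:

import torch

def metrics_from_confmat(confmat: torch.Tensor) -> dict[str, torch.Tensor]:
    tp = confmat.diag().float()
    support = confmat.sum(dim=1).float()    # ground-truth pixels per class (row sums)
    predicted = confmat.sum(dim=0).float()  # predicted pixels per class (column sums)

    recall = tp / support                   # equals per-class accuracy
    precision = tp / predicted
    iou = tp / (support + predicted - tp)

    # micro averaging collapses to overall accuracy
    overall_accuracy = tp.sum() / confmat.sum()
    return {
        'recall': recall,
        'precision': precision,
        'iou': iou,
        'overall_accuracy': overall_accuracy,
    }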
I'm actually pretty dumbfounded that these issues are not a top priority for the TorchMetrics team and that they focus on adding to their docs instead, but to each their own…
@DimitrisMantas good call on my ignoring the ignore_index! In fairness, they do address issues, but they have a long backlog. When I made some noise they addressed https://github.com/Lightning-AI/torchmetrics/pull/2198
My opinion is that it is better to work with torchmetrics to address the issues rather than implement from scratch here. I see your comment at https://github.com/Lightning-AI/torchmetrics/issues/2535#issuecomment-2143514389, so perhaps a pragmatic approach is not to add new metrics that we have concerns about, and also to create specific issues which track these concerns.
Sure, that makes sense; please excuse the rant haha.
Applied on_epoch=True to all steps for consistency. This results in both per-epoch and per-step values being reported for train only; perhaps this is why @isaaccorley did not apply it to train?
train_loss_epoch | 0.028535427525639534
train_loss_step | 0.00008003244874998927
train_AverageAccuracy_epoch | 0.9101453423500061
train_AverageAccuracy_step | 0.9124529361724854
Note that Val is unaffected:
val_AverageAccuracy | 0.8227439522743225
For a task with 2 classes, a grand total of 52 metrics are reported between train & val.
I just set it to be explicit, but I think PyTorch Lightning or torchmetrics automatically sets on_epoch to False for training and True for everything else.
You need to set both on_step and on_epoch to get logs only per step or per epoch.
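For reference, a sketch of those Lightning logging flags (standard LightningModule hooks, not the exact trainer code in this PR):

def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)  # hypothetical helper
    # default in training_step: on_step=True, on_epoch=False;
    # setting both to True yields *_step and *_epoch entries for the same metric
    self.log('train_loss', loss, on_step=True, on_epoch=True)
    return loss

def validation_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)
    # default in validation/test steps: on_step=False, on_epoch=True,
    # so only the epoch-level value is reported
    self.log('val_loss', loss)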
@DimitrisMantas now only logging on_step for the train loss, so a more manageable 36 metrics now.
Not sure about this failing test: ValueError: Problem with given class_path 'torchgeo.trainers.SemanticSegmentationTask'
Must be an issue with one of the minimum versions of the packages, since it's passing for the other tests.
We can definitely increase the min version of torchmetrics if we need to.
If we do this, we should do it for every trainer, not just segmentation. Want to be consistent here.
@adamjstewart I ran the tests on my branch and they pass - are they somehow run differently on CI/CD?
tests/trainers/test_segmentation.py ........................................... [100%]
================================================================================================ warnings summary =================================================================================================
tests/trainers/test_segmentation.py::TestSemanticSegmentationTask::test_trainer[True-spacenet1]
tests/trainers/test_segmentation.py::TestSemanticSegmentationTask::test_trainer[True-spacenet1]
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataset.py:449: UserWarning: Length of split at index 2 is 0. This might result in an empty dataset.
warnings.warn(f"Length of split at index {i} is 0. "
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================= 43 passed, 31 deselected, 2 warnings in 55.59s ==================================================================================
The failing tests are for the minimum supported version of our dependencies. If you install the same version of torchmetrics we're using in https://github.com/microsoft/torchgeo/blob/main/requirements/min-reqs.old, you should be able to reproduce the bug. Feel free to increase that version if needed.
@adamjstewart tests passing
The new min is pretty recent. Did you ever figure out why older versions were raising an error?
This PR does a lot of stuff:
- Add macro, micro, and classwise averaging
- Switch from MulticlassAccuracy to Accuracy, etc.
- Modify on_epoch/on_step logging frequency
Can we split these into separate PRs so it's easier to review and clearer which changes necessitate what? The first couple are likely easy to approve, the latter may require justification. I'm also worried that we'll change these only for segmentation and forget about the rest and end up with non-uniform metrics.
@adamjstewart I also tried torchmetrics==1.1.0 and got the failures, but didn't dig into it.
@adamjstewart I have dropped the on_epoch/on_step modifications from this PR. It feels manageable to me now, but I can split further if required.
@adamjstewart I'm inclined to close this PR as I don't feel confident I understand the behaviour of torchmetrics in this implementation. Elsewhere I am using the on_stage_epoch_end hooks and feel confident I do understand the behaviour with that approach. Overall I think this should be a change we make from a place of understanding, and in smaller steps than this PR takes
torchmetrics==1.1.0 test errors here:
MeanAveragePrecision(), kwargs = {'average': 'macro'}
...
ValueError: Unexpected keyword arguments: `average`
I see this was added in 1.1.1.
After discussion with torchmetrics devs, created https://github.com/Lightning-AI/torchmetrics/issues/2683
That's such a complicated minimal reproducible example lol.
I tried making a self-contained minimal reproducible example but couldn't get one working and gave up.
It just hit me that we should be a bit careful with which metrics we add to avoid unnecessary computation; class-wise accuracy and recall are the same thing, and so are micro-averaged accuracy, precision, and recall.
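A quick sanity check of that redundancy claim (a hedged sketch with random data, not PR code): per-class MulticlassAccuracy equals per-class MulticlassRecall, and the micro-averaged accuracy, precision, and recall all coincide:

import torch
from torchmetrics.classification import (
    MulticlassAccuracy,
    MulticlassPrecision,
    MulticlassRecall,
)

num_classes = 3
preds = torch.randint(0, num_classes, (10_000,))
target = torch.randint(0, num_classes, (10_000,))

# per-class accuracy is computed as TP_c / (TP_c + FN_c), i.e. recall
acc_none = MulticlassAccuracy(num_classes, average=None)(preds, target)
rec_none = MulticlassRecall(num_classes, average=None)(preds, target)
assert torch.allclose(acc_none, rec_none)

# micro averaging counts every prediction once, so all three collapse to overall accuracy
acc_micro = MulticlassAccuracy(num_classes, average='micro')(preds, target)
prec_micro = MulticlassPrecision(num_classes, average='micro')(preds, target)
rec_micro = MulticlassRecall(num_classes, average='micro')(preds, target)
assert torch.isclose(acc_micro, prec_micro) and torch.isclose(acc_micro, rec_micro)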
Any sense of how much these metrics actually add to processing time? If it isn't noticeable by a human, I don't particularly care about the overhead.
Haven't measured it but doubt it's much.
I believe Lightning offers tools for profiling
They do, see https://torchgeo.readthedocs.io/en/latest/user/contributing.html#i-o-benchmarking
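For a rough measurement, Lightning's built-in profilers can be enabled directly on the Trainer; a minimal sketch, assuming the lightning.pytorch import path and a hypothetical task/datamodule:

import lightning.pytorch as pl

# 'simple' reports cumulative time per hook; 'advanced' gives a full cProfile dump
trainer = pl.Trainer(profiler='simple', max_epochs=1)
# trainer.fit(task, datamodule=datamodule)  # hypothetical task and datamodule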
@adamjstewart @DimitrisMantas per this comment we should be using the _epoch_end hooks https://github.com/Lightning-AI/torchmetrics/issues/2683#issuecomment-2331146337
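A sketch of that pattern, assuming a MetricCollection stored as self.val_metrics and a hypothetical _shared_step helper: update in the step, then compute, log, and reset in the epoch-end hook.

def validation_step(self, batch, batch_idx):
    preds, target = self._shared_step(batch)  # hypothetical helper
    self.val_metrics.update(preds, target)    # accumulate state only, no logging here

def on_validation_epoch_end(self):
    self.log_dict(self.val_metrics.compute())  # log the aggregated epoch values
    self.val_metrics.reset()                   # clear state for the next epoch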
I see the issue, but I must be missing something because my own code uses the standard logging tools and metric collections work just fine.
Although by "work", I mean I don't get an error. Other than that, I found out a couple of days ago that the diagonal of my confusion matrix doesn't match the class accuracies (which it should), so I'm obviously not using the API correctly...
Edit: I have at least one mistake where I do self.log_dict(metrics(input, target)). The docs say this is wrong.
Edit 2: Aaaaand I finally got your error...