torchgeo

Multiclass Classification: assert num_classes >=2

Open robmarkcole opened this issue 1 year ago • 4 comments

Summary

Both segmentation and object detection require that the background be included, and there is currently a note on these args: num_classes: Number of prediction classes (including the background). Since every dataset must have at least 1 class, the minimum value of num_classes is 2. I propose adding an assertion to prevent people (like myself!) from forgetting this and setting num_classes=1 for datasets with a single class.

Rationale

This config error has happened to me several times, and it can pass silently.

Implementation

I suppose we could add validation to the BaseTask init.
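A minimal sketch of what such a check could look like. ToyTask is a hypothetical stand-in, not torchgeo's actual BaseTask, and its constructor signature is an assumption. The sketch raises ValueError instead of using a bare assert, because assert statements are skipped when Python runs with -O.

```python
class ToyTask:
    """Hypothetical stand-in for a trainer task; not torchgeo's BaseTask."""

    def __init__(self, num_classes: int) -> None:
        # The background counts as a class, so a dataset with a single
        # foreground class still needs num_classes=2.
        if num_classes < 2:
            raise ValueError(
                "num_classes must be >= 2 (including the background), "
                f"got {num_classes}"
            )
        self.num_classes = num_classes
```

With a check like this, ToyTask(num_classes=1) fails at construction time instead of passing silently.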

Alternatives

No response

Additional information

No response

robmarkcole, Aug 01 '24 13:08

Not to completely derail what should otherwise be a simple fix, but...

This brings up the question of how we want to handle different forms of classification/semantic segmentation:

  • Binary
  • Multiclass
  • Multilabel

Torchmetrics originally had a single class for Accuracy. In https://github.com/Lightning-AI/torchmetrics/issues/1001, they proposed and implemented separate classes for each of the 3 above types of classification (BinaryAccuracy, etc.). The original plan was to deprecate and remove the old single class, but it seems that plan was aborted at some point.

We should decide whether we want BinaryClassificationTask, etc. or whether we want to add a task='binary', etc. parameter to ClassificationTask.

We could definitely still add such an assertion for now and change it to assert num_classes > 1 if task != 'binary' later if needed.
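As a rough illustration of the relaxed check described here, the sketch below only enforces num_classes > 1 when the task is not binary. The helper name check_num_classes and the task parameter are hypothetical; no such parameter exists on ClassificationTask in this discussion yet.

```python
def check_num_classes(num_classes: int, task: str = "multiclass") -> None:
    """Hypothetical helper: require num_classes > 1 unless task is 'binary'."""
    if task != "binary" and num_classes <= 1:
        raise ValueError(
            f"num_classes must be > 1 for task={task!r}, got {num_classes}"
        )
```

Under this rule, check_num_classes(1, task="binary") is accepted, while check_num_classes(1) is rejected.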

adamjstewart, Aug 04 '24 11:08

As you point out, binary etc. are args that torchmetrics accepts, so I think it makes sense to add this functionality to the existing task.

robmarkcole, Aug 04 '24 11:08

Just waiting for clarity on whether torchmetrics is planning on supporting the old metrics forever before deciding, but I was leaning towards that too.

adamjstewart, Aug 04 '24 11:08

Looks like I misinterpreted; both are supported.

Is there anything special we need to do in our trainers to support binary and multilabel, or do we literally just need to pass different task values to torchmetrics? If the former, we may want to split, but if the latter, I agree we should just keep the current classes and add a task parameter.
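For the "just pass different task values" case, a sketch of the metric side might look like the following. The helper metric_kwargs is hypothetical. It only builds keyword arguments and relies on the torchmetrics convention that the task-based wrappers take num_classes for multiclass and num_labels for multilabel. It does not address whether the loss or output activation in the trainer also needs to change, which is the open part of the question.

```python
from typing import Any


def metric_kwargs(task: str, num_classes: int) -> dict[str, Any]:
    """Hypothetical helper: build kwargs for a task-based torchmetrics metric."""
    if task == "binary":
        # Binary metrics take no class count.
        return {"task": "binary"}
    if task == "multiclass":
        return {"task": "multiclass", "num_classes": num_classes}
    if task == "multilabel":
        # Multilabel metrics count labels rather than classes.
        return {"task": "multilabel", "num_labels": num_classes}
    raise ValueError(f"unknown task: {task!r}")


# Intended use (requires torchmetrics, not imported here):
#   torchmetrics.Accuracy(**metric_kwargs("multilabel", 5))
```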

adamjstewart, Aug 04 '24 16:08