torchmetrics

Potentially unsupported version of DICE.

Status: Open. aymuos15 opened this issue 1 year ago (5 comments).

Bug description

https://lightning.ai/docs/torchmetrics/stable/classification/dice.html

On this page, the documentation of the `average` parameter of the Dice score lists 'weighted' as an option, but passing `average='weighted'` raises an error.

How to reproduce the bug

!pip install torchmetrics
from torch import tensor
from torchmetrics.classification import Dice

preds  = tensor([2, 0, 2, 1])
target = tensor([1, 1, 2, 0])
dice = Dice(average='weighted')
dice(preds, target)
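For reference, here is a dependency-free sketch of what a `weighted` average would return for the inputs above. It assumes "weighted" follows the usual convention (as in scikit-learn's `average='weighted'`): compute per-class Dice, 2·TP / (2·TP + FP + FN), then weight each class by its support (number of occurrences in `target`). The helper `weighted_dice` is hypothetical and not part of torchmetrics; plain lists stand in for the tensors.

```python
from collections import Counter


def weighted_dice(preds, target, num_classes):
    """Support-weighted multiclass Dice (hypothetical helper, not a torchmetrics API).

    Per-class Dice is 2*TP / (2*TP + FP + FN); each class score is weighted
    by the number of target samples of that class (its support).
    """
    support = Counter(target)
    total = len(target)
    if total == 0:
        return 0.0
    weighted_sum = 0.0
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(preds, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(preds, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(preds, target) if p != c and t == c)
        denom = 2 * tp + fp + fn
        # Assumption: treat 0/0 as 0.0, like a zero_division=0 convention
        dice_c = 2 * tp / denom if denom else 0.0
        weighted_sum += dice_c * support[c]
    return weighted_sum / total


preds = [2, 0, 2, 1]
target = [1, 1, 2, 0]
print(weighted_dice(preds, target, num_classes=3))  # 0.16666666666666666
```

Here classes 0 and 1 score 0 and class 2 scores 2/3 with support 1, so the weighted result is (2/3 · 1) / 4 = 1/6, whereas a macro average would give 2/9.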

Error messages and logs

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-4-16f87eeaf06b> in <cell line: 6>()
      4 preds  = tensor([2, 0, 2, 1])
      5 target = tensor([1, 1, 2, 0])
----> 6 dice = Dice(average='weighted')
      7 dice(preds, target)

/usr/local/lib/python3.10/dist-packages/torchmetrics/classification/dice.py in __init__(self, zero_division, num_classes, threshold, average, mdmc_average, ignore_index, top_k, multiclass, **kwargs)
    158         allowed_average = ("micro", "macro", "samples", "none", None)
    159         if average not in allowed_average:
--> 160             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
    161 
    162         _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)

ValueError: The `average` has to be one of ('micro', 'macro', 'samples', 'none', None), got weighted.

Environment

Google Colab (hence the `pip install`)

More info

https://colab.research.google.com/drive/1GuBr6kI0ypj0sUDaqGQ8dN0Td-sbYWGI?usp=sharing

aymuos15, Jan 24 '24 20:01

Hi! Thanks for your contribution, great first issue!

github-actions[bot], Jan 24 '24 20:01

I see the issue.

`weighted` is listed in the docstring, but not accepted by the code. I guess `weighted` is not supported (?)

https://github.com/Lightning-AI/torchmetrics/blob/a68455afb9041d1d32c1d6546897fee416abdc41/src/torchmetrics/classification/dice.py#L71

https://github.com/Lightning-AI/torchmetrics/blob/a68455afb9041d1d32c1d6546897fee416abdc41/src/torchmetrics/classification/dice.py#L158

It looks like most of the `weighted` code was removed in this commit:

https://github.com/Lightning-AI/torchmetrics/commit/2509448389c2c5bdd3305721e74cedd6bd26c57b
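Until the docs and code agree, one possible workaround is to get per-class scores and apply the support weighting yourself. This sketch assumes, without having verified it, that something like `Dice(average=None, num_classes=3)` returns one score per class; the helper `support_weighted_average` is hypothetical and the per-class values below are computed by hand for the reproduction example rather than taken from torchmetrics.

```python
from collections import Counter


def support_weighted_average(per_class_scores, target):
    """Combine per-class scores into a support-weighted mean (hypothetical helper).

    Assumption: per_class_scores[c] is the score for class c, e.g. as returned
    by a per-class (average=None) Dice call; not verified against torchmetrics.
    """
    support = Counter(target)
    total = sum(support[c] for c in range(len(per_class_scores)))
    if total == 0:
        return 0.0
    return sum(score * support[c] for c, score in enumerate(per_class_scores)) / total


# Per-class Dice for preds=[2, 0, 2, 1], target=[1, 1, 2, 0], computed by hand:
# class 0 -> 0.0, class 1 -> 0.0, class 2 -> 2/3
print(support_weighted_average([0.0, 0.0, 2 / 3], [1, 1, 2, 0]))  # 0.16666666666666666
```

Weighting by support means frequent classes dominate the average, which is the main difference from `macro`, where every class counts equally.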

SuperSecureHuman, Jan 27 '24 11:01

Great find, thank you. I'll keep the issue open for the time being.

aymuos15, Jan 30 '24 19:01

I think this should be fixed in the docs. @SkafteNicki, thoughts?

Borda, Jan 31 '24 10:01