torchmetrics
Potentially unsupported option of Dice: average='weighted'
Bug description
https://lightning.ai/docs/torchmetrics/stable/classification/dice.html
The documentation for the average parameter of the Dice score lists 'weighted' as an option, but passing it raises an error.
How to reproduce the bug
!pip install torchmetrics
from torch import tensor
from torchmetrics.classification import Dice
preds = tensor([2, 0, 2, 1])
target = tensor([1, 1, 2, 0])
dice = Dice(average='weighted')
dice(preds, target)
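For reference, the value that average='weighted' is documented to return can be computed by hand. The sketch below is plain Python and assumes the usual torchmetrics meaning of "weighted": each class's Dice score, 2·tp / (2·tp + fp + fn), is averaged with weights equal to that class's support (tp + fn). The helpers per_class_stats, dice_per_class and weighted_dice are hypothetical names, not part of torchmetrics, and they only cover single-label multiclass inputs like the ones above.

```python
from typing import Dict, List, Sequence


def per_class_stats(preds: Sequence[int], target: Sequence[int], num_classes: int) -> List[Dict[str, int]]:
    """Count true positives, false positives and false negatives per class (hypothetical helper)."""
    stats = [{"tp": 0, "fp": 0, "fn": 0} for _ in range(num_classes)]
    for p, t in zip(preds, target):
        if p == t:
            stats[t]["tp"] += 1
        else:
            stats[p]["fp"] += 1  # class p was predicted, but wrongly
            stats[t]["fn"] += 1  # true class t was missed
    return stats


def dice_per_class(stats: List[Dict[str, int]], zero_division: float = 0.0) -> List[float]:
    """Dice = 2*tp / (2*tp + fp + fn); classes with an empty denominator get zero_division."""
    scores = []
    for s in stats:
        denom = 2 * s["tp"] + s["fp"] + s["fn"]
        scores.append(2 * s["tp"] / denom if denom else zero_division)
    return scores


def weighted_dice(preds: Sequence[int], target: Sequence[int], num_classes: int,
                  zero_division: float = 0.0) -> float:
    """Support-weighted mean of the per-class Dice scores (assumed meaning of 'weighted')."""
    stats = per_class_stats(preds, target, num_classes)
    scores = dice_per_class(stats, zero_division)
    supports = [s["tp"] + s["fn"] for s in stats]  # support = number of target samples per class
    total = sum(supports)
    if total == 0:
        return zero_division
    return sum(w * d for w, d in zip(supports, scores)) / total


preds = [2, 0, 2, 1]
target = [1, 1, 2, 0]
print(weighted_dice(preds, target, num_classes=3))  # 1/6, about 0.1667
```

Here the per-class scores are 0, 0 and 2/3 with supports 1, 2 and 1, so the weighted average is (2/3) / 4 = 1/6. An unweighted mean of the same per-class scores would be 2/9, which is why the two options differ whenever class supports are unequal.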
Error messages and logs
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-4-16f87eeaf06b> in <cell line: 6>()
4 preds = tensor([2, 0, 2, 1])
5 target = tensor([1, 1, 2, 0])
----> 6 dice = Dice(average='weighted')
7 dice(preds, target)
/usr/local/lib/python3.10/dist-packages/torchmetrics/classification/dice.py in __init__(self, zero_division, num_classes, threshold, average, mdmc_average, ignore_index, top_k, multiclass, **kwargs)
158 allowed_average = ("micro", "macro", "samples", "none", None)
159 if average not in allowed_average:
--> 160 raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
161
162 _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None)
ValueError: The `average` has to be one of ('micro', 'macro', 'samples', 'none', None), got weighted.
Environment
Google Colab (hence the !pip install)
More info
https://colab.research.google.com/drive/1GuBr6kI0ypj0sUDaqGQ8dN0Td-sbYWGI?usp=sharing
Hi! Thanks for your contribution, great first issue!
I see the issue.
It's in the docstring, but not in the code. I guess weighted is not supported (?)
https://github.com/Lightning-AI/torchmetrics/blob/a68455afb9041d1d32c1d6546897fee416abdc41/src/torchmetrics/classification/dice.py#L71
https://github.com/Lightning-AI/torchmetrics/blob/a68455afb9041d1d32c1d6546897fee416abdc41/src/torchmetrics/classification/dice.py#L158
Looks like most of the weighted code was removed in this commit:
https://github.com/Lightning-AI/torchmetrics/commit/2509448389c2c5bdd3305721e74cedd6bd26c57b
Great find, thank you. Will keep the issue open for the time being.
I think this should be fixed in the docs. @SkafteNicki, thoughts?