torchmetrics
Fix/multiclass recall macro avg ignore index
What does this PR do?
Fixes #2441
Details
- [x] Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
- [x] Did you read the contributor guideline, Pull Request section?
- [ ] Did you make sure to update the docs?
- [x] Did you write any new necessary tests?
Did you have fun?
Yes
Issue:
- The root of the problem seems to be that the `ignore_index` information is not properly propagated to the final averaging step, i.e. the `_adjust_weights_safe_divide` function doesn't know which class should be ignored.
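A minimal pure-Python sketch of the failure mode, using plain lists instead of tensors. The helpers `per_class_counts` and `macro_recall_unaware` are hypothetical illustrations, not torchmetrics functions. The weighting rule is an assumption about the pre-fix behaviour: a class is dropped from the macro average only when its tp, fp and fn are all zero. With `ignore_index=0`, class 0 has no support but still picks up a false positive, so its recall of 0 is averaged in:

```python
from typing import List, Optional, Tuple


def per_class_counts(
    preds: List[int], target: List[int], num_classes: int, ignore_index: Optional[int] = None
) -> Tuple[List[int], List[int], List[int]]:
    """Count per-class tp, fp and fn, skipping samples whose target equals ignore_index."""
    tp = [0] * num_classes
    fp = [0] * num_classes
    fn = [0] * num_classes
    for p, t in zip(preds, target):
        if ignore_index is not None and t == ignore_index:
            continue  # the sample is ignored entirely
        if p == t:
            tp[t] += 1
        else:
            fp[p] += 1  # the predicted class gets a false positive
            fn[t] += 1  # the true class gets a false negative
    return tp, fp, fn


def macro_recall_unaware(tp: List[int], fp: List[int], fn: List[int]) -> float:
    """Macro recall whose weighting knows nothing about ignore_index (assumed pre-fix rule)."""
    n = len(tp)
    scores = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0 for c in range(n)]
    # a class is only dropped when it has no tp, fp or fn at all
    weights = [0.0 if tp[c] + fp[c] + fn[c] == 0 else 1.0 for c in range(n)]
    total = sum(weights)
    return sum(w * s for w, s in zip(weights, scores)) / total if total else 0.0


preds = [0, 0, 1, 2, 2]
target = [0, 1, 1, 2, 2]
tp, fp, fn = per_class_counts(preds, target, num_classes=3, ignore_index=0)
print(tp, fp, fn)  # [0, 1, 2] [1, 0, 0] [0, 1, 0]
print(macro_recall_unaware(tp, fp, fn))  # 0.5, although classes 1 and 2 average to 0.75
```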
Solution:
- To address this issue, I updated the code so that the `ignore_index` information is preserved throughout the entire process and is correctly passed through all intermediate steps up to the final averaging stage, i.e. the `_adjust_weights_safe_divide` function.
- Updated the `_adjust_weights_safe_divide` function to accept an additional `ignore_index` parameter, which is passed through the `_precision_recall_reduce` function called in the `compute` method of the `MulticlassRecall` class. This change adjusts the weights in `_adjust_weights_safe_divide`, setting the weight of the ignored class to 0.
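A list-based sketch of the weighting step with the extra `ignore_index` argument. The name `adjust_weights_safe_divide`, its signature and the plain-list types are hypothetical simplifications of the tensor-based torchmetrics helper. Only the idea of zeroing the ignored class's weight comes from the change described here:

```python
from typing import List, Optional


def adjust_weights_safe_divide(
    score: List[float],
    average: str,
    tp: List[int],
    fp: List[int],
    fn: List[int],
    ignore_index: Optional[int] = None,
) -> float:
    """Weighted average of per-class scores in which the ignored class gets weight 0."""
    n = len(score)
    if average == "weighted":
        # weight each class by its support
        weights = [float(tp[c] + fn[c]) for c in range(n)]
    else:
        # any other average (e.g. "macro"): drop classes with no tp, fp or fn at all
        weights = [0.0 if tp[c] + fp[c] + fn[c] == 0 else 1.0 for c in range(n)]
    if ignore_index is not None and 0 <= ignore_index < n:
        weights[ignore_index] = 0.0  # the ignored class never enters the average
    total = sum(weights)
    # safe divide: return 0.0 instead of raising when every weight is zero
    return sum(w * s for w, s in zip(weights, score)) / total if total else 0.0


tp, fp, fn = [0, 1, 2], [1, 0, 0], [0, 1, 0]
recall = [0.0, 0.5, 1.0]
print(adjust_weights_safe_divide(recall, "macro", tp, fp, fn))  # 0.5
print(adjust_weights_safe_divide(recall, "macro", tp, fp, fn, ignore_index=0))  # 0.75
```

This sketch assumes that an `ignore_index` outside `[0, num_classes)` (for example a sentinel such as -1) has no class column, so it leaves the weights untouched. For `average="weighted"` the ignored class already has zero support (tp + fn = 0), so the extra step only changes the macro result.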
📚 Documentation preview 📚: https://torchmetrics--2710.org.readthedocs.build/en/2710/
Looks good, can we also add a test for this case?
Sure
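One way such a test could be set up, sketched in plain Python: compute a reference value that averages recall only over the labels that are not ignored, then check the metric against it. The function names are hypothetical. In the torchmetrics suite the reference would more likely come from sklearn (e.g. `recall_score(target, preds, labels=labels, average="macro")`), which should compute the same quantity, and the checked value would come from `MulticlassRecall`:

```python
from typing import List, Sequence


def reference_macro_recall(preds: Sequence[int], target: Sequence[int], labels: Sequence[int]) -> float:
    """Average the per-class recall over `labels` only."""
    recalls: List[float] = []
    for label in labels:
        support = sum(1 for t in target if t == label)
        hits = sum(1 for p, t in zip(preds, target) if p == t == label)
        recalls.append(hits / support if support else 0.0)
    return sum(recalls) / len(recalls) if recalls else 0.0


def test_macro_recall_skips_ignored_class() -> None:
    preds = [0, 0, 1, 2, 2]
    target = [0, 1, 1, 2, 2]
    ignore_index = 0
    # the ignored class is left out of the average altogether
    labels = [c for c in range(3) if c != ignore_index]
    expected = reference_macro_recall(preds, target, labels)
    assert expected == 0.75
    # In the real suite the metric output would be compared here, e.g.
    # MulticlassRecall(num_classes=3, average="macro", ignore_index=ignore_index)(
    #     torch.tensor(preds), torch.tensor(target)) checked against `expected`.


test_macro_recall_skips_ignored_class()
```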
@Borda What do I have to modify?
@rittik9 mind checking the changed doctest values and whether they are correct?
@Borda I checked the tests I wrote and I believe they are correct. I plan to take a detailed look at the failing tests later.
Any update on when this PR could be merged? It would really help if we could update from the 0.9.3 version once this fix is merged.
The tests/doctests need to be fixed. Are you interested in submitting a suggestion on what else needs to be fixed/changed?
Just to chime in, I think this issue is present in pretty much all metrics that make use of `_adjust_weights_safe_divide`.
I see this PR fixes some of them, but others, such as `JaccardIndex`, are left as is.
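To illustrate the point about other metrics, here is the same weighting issue for a macro-averaged Jaccard index (per-class IoU = tp / (tp + fp + fn)). This is a hypothetical pure-Python sketch, not the `JaccardIndex` implementation. The counts are a small made-up example in which class 0 is ignored but still receives one false positive:

```python
from typing import List, Optional


def macro_jaccard(tp: List[int], fp: List[int], fn: List[int], ignore_index: Optional[int] = None) -> float:
    """Macro IoU: classes with no tp, fp or fn are dropped; optionally also the ignored class."""
    n = len(tp)
    iou = [tp[c] / (tp[c] + fp[c] + fn[c]) if tp[c] + fp[c] + fn[c] else 0.0 for c in range(n)]
    weights = [0.0 if tp[c] + fp[c] + fn[c] == 0 else 1.0 for c in range(n)]
    if ignore_index is not None and 0 <= ignore_index < n:
        weights[ignore_index] = 0.0  # exclude the ignored class from the mean
    total = sum(weights)
    return sum(w * s for w, s in zip(weights, iou)) / total if total else 0.0


# per-class counts for classes 0..2 after dropping samples whose target is class 0
tp, fp, fn = [0, 1, 2], [1, 0, 0], [0, 1, 0]
print(macro_jaccard(tp, fp, fn))  # 0.5: the ignored class's IoU of 0 drags the mean down
print(macro_jaccard(tp, fp, fn, ignore_index=0))  # 0.75
```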
Shall we fix the unit tests, then, or was this already resolved with sklearn in the meantime?
I think we gotta fix the unit tests first