
Support ClasswiseWrapper Metrics for Classification Tasks

Status: Open • newzealandpaul opened this issue 1 year ago • 0 comments

🚀 Feature

Currently, torchmetrics' ClasswiseWrapper is not supported by Lightning. The wrapper takes a metric that returns one value per class and reports those values as a dict keyed by class.

Motivation

Per-class metrics are essential for many classification tasks. They show how the model performs on each class, which a single averaged score can hide.

Pitch

Currently, passing ClasswiseWrapper() metrics when creating a new instance of a Lightning model raises an error at flash/core/model.py:373, because ClasswiseWrapper objects do not have a _forward_cache attribute. Once that is fixed, a second error is raised at trainer/connectors/logger_connector/result.py:548, because the logger expects a tensor, not a dict of tensors.
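Until this is supported natively, one possible workaround is to avoid handing the logger a dict at all. Instead, you could flatten the wrapper's per-class dict into one scalar entry per class before logging. The sketch below shows only that flattening step, under stated assumptions. flatten_metric_logs is a hypothetical helper and is not part of Flash, Lightning, or torchmetrics. Plain floats stand in for the 0-dim tensors the wrapper would actually return.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Dict


def flatten_metric_logs(logs: Mapping[str, Any], sep: str = "/") -> Dict[str, Any]:
    """Flatten {name: value | {sub_name: value}} into {name<sep>sub_name: value}.

    Hypothetical helper: any value that is itself a mapping (for example the
    per-class dict produced by a ClasswiseWrapper) is expanded recursively,
    so every value in the result is a single scalar the logger can accept.
    """
    flat: Dict[str, Any] = {}
    for name, value in logs.items():
        if isinstance(value, Mapping):
            # Expand nested dicts one key per entry, prefixing with the parent name.
            for sub_name, sub_value in flatten_metric_logs(value, sep).items():
                flat[f"{name}{sep}{sub_name}"] = sub_value
        else:
            # Already a scalar (float or 0-dim tensor): keep it as-is.
            flat[name] = value
    return flat


if __name__ == "__main__":
    logs = {
        "val_accuracy": 0.9,
        "val_classwise": {"multiclassaccuracy_cat": 0.8, "multiclassaccuracy_dog": 1.0},
    }
    print(flatten_metric_logs(logs))
```

In a LightningModule, you could then, for example, call self.log_dict(flatten_metric_logs({"val_classwise": classwise.compute()})) at the end of a validation epoch. Here classwise is assumed to be a ClasswiseWrapper that you update yourself in each step and reset afterwards. This works around Flash's metrics argument rather than fixing it.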

Users would expect torchmetrics features to be supported natively.

newzealandpaul, Dec 08 '22 02:12