
Allow arbitrary types for metric states

Open LarsHill opened this issue 2 years ago • 6 comments

🚀 Feature

Metric states seem to be limited to torch.Tensor or List[torch.Tensor].

In my use case I want to store a dictionary as state. My dataset consists of samples that can be assigned to different documents. To calculate macro metrics (compute the metrics per document, then average) I want to store my metric states (e.g. true positives, false positives, etc.) in a dictionary. Here is some pseudocode:

from collections import defaultdict
from typing import List

import torch
from torchmetrics import Metric


class MyMetric(Metric):

    def __init__(self, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)

        self.add_state("statistics", default=defaultdict(lambda: defaultdict(float)), dist_reduce_fx=None)

    def update(
            self,
            predictions: torch.Tensor,
            targets: torch.Tensor,
            document_ids: List
    ):
        predictions = predictions.bool()
        targets = targets.bool()

        tps = predictions * targets
        tns = predictions.logical_not() * targets.logical_not()
        fps = predictions * targets.logical_not()
        fns = predictions.logical_not() * targets

        for id_, tp, tn, fp, fn in zip(document_ids, tps, tns, fps, fns):
            self.statistics[id_]['tp'] += tp.float().item()
            self.statistics[id_]['tn'] += tn.float().item()
            self.statistics[id_]['fp'] += fp.float().item()
            self.statistics[id_]['fn'] += fn.float().item()

    def compute(self):
        ...

Unfortunately, the above code is not allowed: each metric state has to be a torch.Tensor or a List[torch.Tensor]. That means plain float values or numpy arrays cannot be used as metric states either. Is there a particular reason for that?

LarsHill avatar Apr 26 '22 14:04 LarsHill

Hi! thanks for your contribution!, great first issue!

github-actions[bot] avatar Apr 26 '22 14:04 github-actions[bot]

Hey,

Are there any updates on this matter? After checking metric.py I don't see a particular reason for this restriction, but maybe I overlooked something. Also, the reset method could simply reset each state variable to the user-defined default instead of checking the type to distinguish between List and torch.Tensor. Regarding the reduction logic, one could keep the current approach: if the state type is something custom (e.g. a Dict), the synced state is simply a List[Dict] and the user is in charge of providing a custom reduce fn to combine the different states across processes.
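Such a user-supplied reduce fn could look like the sketch below: it takes the per-process dicts (the List[Dict] described above) and sums the counters key by key. The function name is illustrative, not torchmetrics API.

```python
from collections import defaultdict
from typing import Dict, List


def merge_statistics(
    states: List[Dict[str, Dict[str, float]]]
) -> Dict[str, Dict[str, float]]:
    """Sum per-document counters gathered from all processes."""
    merged: Dict[str, Dict[str, float]] = defaultdict(lambda: defaultdict(float))
    for state in states:
        for doc_id, counters in state.items():
            for name, value in counters.items():
                merged[doc_id][name] += value
    return merged
```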

LarsHill avatar Jun 17 '22 14:06 LarsHill

@LarsHill, I agree that this is an important question to ask. I am also encountering a use case where I need to track my custom Metric's state with a dictionary. Any thoughts on this query, @SkafteNicki or @Borda?

amorehead avatar Sep 18 '22 01:09 amorehead