torchmetrics
Allow arbitrary types for metric states
🚀 Feature
Metric states seem to be limited to torch.Tensor or List[torch.Tensor].
In my use case I want to store a dictionary as state. My dataset consists of samples that can be assigned to different documents. In order to calculate macro metrics (calculate metrics per document, then average), I want to store my metric states (e.g. true positives, false positives, etc.) as a dictionary. Here is some pseudocode:
from collections import defaultdict
from typing import List

import torch
from torchmetrics import Metric


class MyMetric(Metric):
    def __init__(self, dist_sync_on_step: bool = False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        # Desired state: one dict of counts per document ID.
        self.add_state("statistics", default=defaultdict(lambda: defaultdict(float)), dist_reduce_fx=None)

    def update(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        document_ids: List
    ):
        predictions = predictions.bool()
        targets = targets.bool()
        # Element-wise confusion-matrix indicators for each sample.
        tps = predictions * targets
        tns = predictions.logical_not() * targets.logical_not()
        fps = predictions * targets.logical_not()
        fns = predictions.logical_not() * targets
        # Accumulate the counts under the document each sample belongs to.
        for id_, tp, tn, fp, fn in zip(document_ids, tps, tns, fps, fns):
            self.statistics[id_]['tp'] += tp.float().item()
            self.statistics[id_]['tn'] += tn.float().item()
            self.statistics[id_]['fp'] += fp.float().item()
            self.statistics[id_]['fn'] += fn.float().item()

    def compute(self):
        ...
Unfortunately the above code is not allowed. Each metric state has to be a torch.Tensor or a List[torch.Tensor].
That means normal float values or numpy arrays cannot be used as metric states either. Is there a particular reason for that?
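To make the intended macro averaging concrete, here is a minimal sketch of what compute could do with such a dictionary state. It is plain Python and does not touch torchmetrics. The helper name macro_precision, the choice of precision as the metric, and the convention of scoring a document without positive predictions as 0.0 are all assumptions for illustration, not part of the original pseudocode.

```python
from typing import Dict, Hashable


def macro_precision(statistics: Dict[Hashable, Dict[str, float]]) -> float:
    """Average per-document precision over all documents (hypothetical helper)."""
    if not statistics:
        return 0.0
    per_document = []
    for counts in statistics.values():
        tp = counts.get("tp", 0.0)
        fp = counts.get("fp", 0.0)
        denominator = tp + fp
        # Assumption: a document with no positive predictions scores 0.0.
        per_document.append(tp / denominator if denominator > 0 else 0.0)
    return sum(per_document) / len(per_document)


# Example per-document counts, shaped like the "statistics" state above.
stats = {
    "doc_a": {"tp": 3.0, "tn": 1.0, "fp": 1.0, "fn": 0.0},
    "doc_b": {"tp": 1.0, "tn": 2.0, "fp": 1.0, "fn": 1.0},
}
result = macro_precision(stats)
```

Each document contributes equally to the result regardless of how many samples it contains, which is the property that makes per-document bookkeeping necessary in the first place.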
Hi! Thanks for your contribution! Great first issue!
Hey,
Are there any updates on this matter? After checking metric.py I don't see a particular reason for this restriction, but maybe I overlooked something.
Also, the reset method could simply reset each state variable to the user-defined default instead of checking the type to distinguish between List and torch.Tensor. Regarding the reduction logic, one could keep the current approach, and if the state type is something custom (e.g. a Dict), the synced state is simply a List[Dict] and the user is in charge of providing a custom reduce function to combine the states across processes.
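As a rough illustration of this proposal, the sketch below shows a user-supplied reduce function that merges a List[Dict] gathered from several processes, plus a type-agnostic reset that copies the user-defined default. It is plain Python under stated assumptions: the names merge_statistics and reset_to_default are hypothetical, and whether torchmetrics would hand the gathered states to the reduce function in exactly this shape is an assumption, not current library behavior.

```python
import copy
from collections import defaultdict
from typing import Dict, Hashable, List

Statistics = Dict[Hashable, Dict[str, float]]


def merge_statistics(gathered: List[Statistics]) -> Statistics:
    """Sum per-document counts from every process into one plain dict."""
    merged = defaultdict(lambda: defaultdict(float))
    for process_state in gathered:
        for doc_id, counts in process_state.items():
            for name, value in counts.items():
                merged[doc_id][name] += value
    # Convert back to plain dicts so the result has no default factories.
    return {doc_id: dict(counts) for doc_id, counts in merged.items()}


def reset_to_default(default: Statistics) -> Statistics:
    """Return a fresh copy of the default so later updates cannot mutate it."""
    return copy.deepcopy(default)


# Simulated states from two processes (assumed shape: one dict per process).
rank0 = {"doc_a": {"tp": 2.0, "fp": 1.0}}
rank1 = {"doc_a": {"tp": 1.0, "fp": 0.0}, "doc_b": {"tp": 4.0, "fp": 2.0}}
merged = merge_statistics([rank0, rank1])

# Resetting copies the default instead of branching on the state's type.
default_state: Statistics = {}
fresh_state = reset_to_default(default_state)
fresh_state["doc_a"] = {"tp": 1.0}
```

Copying the default is what lets reset stay independent of the state type: the same code path works for a tensor, a list, or a dictionary, as long as the default can be deep-copied.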
@LarsHill, I agree that this is an important question to ask. I am also encountering a use case where I need to track my custom Metric's state with a dictionary. Any thoughts on this query, @SkafteNicki or @Borda?