
Metric with multiple inputs runs in an unexpected way.


❓ Questions/Help/Support

My custom loss requires two pairs of inputs:

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        # Weighted sum of a regression term and a classification term
        a_true, b_true = y_true
        a_pred, b_pred = y_pred
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)
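
Called directly, the loss works as expected; a minimal check (the shapes below are arbitrary, purely for illustration):

loss = MyLoss(0.5, 1.0)
y_pred = (torch.rand(4, 1), torch.rand(4, 3))            # (a_pred, b_pred)
y_true = (torch.rand(4, 1), torch.randint(0, 3, (4,)))   # (a_true, b_true)
print(loss(y_pred, y_true))  # a scalar tensor, as expected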

When I try to log the loss with the Loss metric:

loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)

It crashes at this line: https://github.com/pytorch/ignite/blob/4825bb6ecd5f432f0423bf3d587fbce04b9d2d4f/ignite/metrics/metric.py#L308 because the metric treats the elements of the output as independent (y_pred, y) pairs, which is not what MyLoss needs.
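
To illustrate the unrolling (a hypothetical paraphrase, not ignite's actual code; see the linked line for the real logic):

import torch

# Shapes mirror the example above: (a, b) pairs for predictions and targets
a_pred, b_pred = torch.rand(4, 1), torch.rand(4, 2)
a_true, b_true = torch.rand(4, 1), torch.rand(4, 2)

y_pred, y = (a_pred, b_pred), (a_true, b_true)

# Since both halves are sequences of tensors, the metric zips them and
# calls update() once per element-wise pair...
for p, t in zip(y_pred, y):
    print(p.shape, t.shape)  # update() sees (a_pred, a_true), then (b_pred, b_true)

# ...so MyLoss.forward never receives the full tuples it expects.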

Digging into the source code, I found that https://github.com/pytorch/ignite/pull/2055 introduced a new feature that causes this issue. So, what is the best practice for dealing with losses that take multiple inputs?

lyhyl · May 09 '23 14:05

@lyhyl thanks for reporting this issue!

Right now, a workaround is to replace the Tuple[torch.Tensor, torch.Tensor] structure with something non-iterable, to prevent unrolling by Metric.

import torch
import torch.nn as nn
import torch.nn.functional as F

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss


class TargetsPair:
    # A plain object rather than a tuple/list, so Metric's
    # sequence-unrolling logic leaves it intact
    a: torch.Tensor
    b: torch.Tensor

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __len__(self):
        # Loss uses len(y) as the batch size by default
        return len(self.a)


class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: TargetsPair, y_true: TargetsPair) -> torch.Tensor:
        a_true, b_true = y_true.a, y_true.b
        a_pred, b_pred = y_pred.a, y_pred.b
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)


def prepare_batch(batch, device, non_blocking):
    # Dummy batch for illustration: x, (a_true, b_true)
    return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2))


class MyModel(nn.Module):
    # Dummy model returning random (a_pred, b_pred) predictions
    def forward(self, x):
        return torch.rand(4, 1), torch.rand(4, 2)


model = MyModel()


def output_transform(output):
    # Repack the raw (y_pred, y) tuples into non-iterable TargetsPair
    # objects so that Metric passes them through to MyLoss untouched
    (a_pred, b_pred), (a_true, b_true) = output
    return TargetsPair(a_pred, b_pred), TargetsPair(a_true, b_true)


device = "cpu"
loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss, output_transform=output_transform)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)


data = range(10)  # 10 dummy iterations; prepare_batch generates the tensors
train_evaluator.run(data)
print(train_evaluator.state.metrics["Loss"])
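
The key point is that TargetsPair is not a tuple, list, or Sequence, so the unrolling branch is skipped and the pair objects reach MyLoss.forward intact, while __len__ still lets Loss infer the batch size (its batch_size argument defaults to len). A quick check, reusing the classes above:

from collections.abc import Sequence

pair = TargetsPair(torch.rand(4, 1), torch.rand(4, 2))
print(isinstance(pair, (tuple, list, Sequence)))  # False -> no unrolling
print(len(pair))  # 4 -> Loss can still compute the running average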

In the future, we may introduce a flag in the Metric class to skip output unrolling and feed the output directly into the update function.

vfdev-5 · May 09 '23 15:05