ignite
Metric with multiple inputs behaves unexpectedly
❓ Questions/Help/Support
My custom loss takes two pairs of inputs:
```python
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        a_true, b_true = y_true
        a_pred, b_pred = y_pred
        # Weighted sum of a regression term (a) and a classification term (b).
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)
```
When I try to log this loss with the Loss metric:
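```python
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss

# model, device and prepare_batch are defined elsewhere.
loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)
```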
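it crashes at this line:
https://github.com/pytorch/ignite/blob/4825bb6ecd5f432f0423bf3d587fbce04b9d2d4f/ignite/metrics/metric.py#L308
because Metric treats the elements of y_pred and y as independent (y_pred, y) pairs: it splits the output into (a_pred, a_true) and (b_pred, b_true) and passes each pair to update() separately, which is not what MyLoss needs.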
Digging into the source code, I found that https://github.com/pytorch/ignite/pull/2055 introduced a new feature that causes this issue. What is the best practice for losses that take multiple inputs?
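To see why the loss breaks, here is a deliberately simplified, PyTorch-free model of the unrolling behavior described above. It is written from the description in this thread, not copied from ignite's metric.py (the real code performs additional checks); FakeTensor, is_sequence_of_tensors, dispatch and PairWrapper are hypothetical names used only for illustration. It shows that a tuple of tensor tuples is split into separate (y_pred, y) calls, while a plain (non-Sequence) wrapper object is forwarded whole.

```python
from collections.abc import Sequence


class FakeTensor:
    """Stand-in for torch.Tensor so this sketch runs without PyTorch."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


def is_sequence_of_tensors(x):
    # A tuple/list whose items are all "tensors" is a candidate for unrolling.
    return isinstance(x, Sequence) and all(isinstance(t, FakeTensor) for t in x)


def dispatch(output, update):
    """Hypothetical, simplified model of how a Metric forwards (y_pred, y) to update()."""
    if isinstance(output, Sequence) and all(is_sequence_of_tensors(o) for o in output):
        y_pred, y = output
        # Unrolling: item i of y_pred is paired with item i of y.
        for p, t in zip(y_pred, y):
            update((p, t))
    else:
        # Anything else is passed through unchanged.
        update(output)


class PairWrapper:
    # Not a Sequence subclass, so dispatch() does not unroll it.
    def __init__(self, a, b):
        self.a, self.b = a, b


a_pred, b_pred, a_true, b_true = (FakeTensor(n) for n in ("a_pred", "b_pred", "a_true", "b_true"))

calls = []
dispatch(((a_pred, b_pred), (a_true, b_true)), calls.append)
print(calls)  # [(a_pred, a_true), (b_pred, b_true)]

calls = []
dispatch((PairWrapper(a_pred, b_pred), PairWrapper(a_true, b_true)), calls.append)
print(len(calls))  # 1
```

In the first call the loss would receive a_pred and a_true on their own, so unpacking them into two pairs fails; in the second call it receives both wrappers at once.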
@lyhyl thanks for reporting this issue!
Right now, a workaround is to replace the Tuple[torch.Tensor, torch.Tensor] structure with a non-iterable wrapper object, so that Metric does not unroll it:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss


class TargetsPair:
    # A plain object, not a Sequence, so Metric passes it through without unrolling.
    a: torch.Tensor
    b: torch.Tensor

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __len__(self):
        # Loss uses len(y) as the batch size by default.
        return len(self.a)


class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: TargetsPair, y_true: TargetsPair) -> torch.Tensor:
        a_true, b_true = y_true.a, y_true.b
        a_pred, b_pred = y_pred.a, y_pred.b
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)


def prepare_batch(batch, device, non_blocking):
    # Dummy batch: input x and a pair of targets (a_true, b_true).
    return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2))


class MyModel(nn.Module):
    def forward(self, x):
        # Dummy model returning a pair of predictions (a_pred, b_pred).
        return torch.rand(4, 1), torch.rand(4, 2)


model = MyModel()


def output_transform(output):
    # The evaluator output is (y_pred, y); wrap each pair so it is not unrolled.
    (a_pred, b_pred), (a_true, b_true) = output
    return TargetsPair(a_pred, b_pred), TargetsPair(a_true, b_true)


device = "cpu"
loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss, output_transform=output_transform)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)

data = range(10)
train_evaluator.run(data)
print(train_evaluator.state.metrics["Loss"])
```
In the future, we may add a flag to the Metric class that skips output unrolling and passes the output to update() as is.
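To make the idea of such a flag concrete, here is a toy sketch. ToyMetric and its skip_unrolling parameter are hypothetical names that only illustrate the idea; this is not ignite's Metric, and plain numbers stand in for tensors. With the flag set, the unrolling branch is bypassed, so update() receives the (y_pred, y) pair exactly as produced by output_transform.

```python
from collections.abc import Sequence


class ToyMetric:
    """Toy stand-in for a Metric, illustrating a hypothetical skip_unrolling flag."""

    def __init__(self, output_transform=lambda x: x, skip_unrolling=False):
        self._output_transform = output_transform
        self._skip_unrolling = skip_unrolling
        self.updates = []  # records what update() received, for inspection

    def update(self, output):
        self.updates.append(output)

    def iteration_completed(self, engine_output):
        output = self._output_transform(engine_output)
        y_pred, y = output
        should_unroll = (
            not self._skip_unrolling
            and isinstance(y_pred, Sequence)
            and isinstance(y, Sequence)
        )
        if should_unroll:
            # Default behavior: pair item i of y_pred with item i of y.
            for p, t in zip(y_pred, y):
                self.update((p, t))
        else:
            # Flag set (or non-sequence output): forward the output untouched.
            self.update(output)


output = ((1.0, 2.0), (1.5, 2.5))  # ((a_pred, b_pred), (a_true, b_true))

default = ToyMetric()
default.iteration_completed(output)
print(len(default.updates))  # 2

raw = ToyMetric(skip_unrolling=True)
raw.iteration_completed(output)
print(raw.updates)  # [((1.0, 2.0), (1.5, 2.5))]
```

An opt-in flag keeps the unrolling feature from the PR as the default while letting multi-input losses such as MyLoss receive their tuples intact, without a wrapper class.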