
Improve Numerical Stability of Function calc_loss

Open ZaydH opened this issue 4 years ago • 1 comments

Summary: We can improve the numerical stability/accuracy of the calc_loss function.

The current implementation uses the following:

def calc_loss(y, t):
    y = torch.nn.functional.log_softmax(y)
    loss = torch.nn.functional.nll_loss(
        y, t, weight=None, reduction='mean')
    return loss

PyTorch includes a single function, cross_entropy, that is more numerically stable. It would also simplify the above code to:

def calc_loss(y, t):
    loss = torch.nn.functional.cross_entropy(y, t, weight=None, reduction="mean")
    return loss
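As a quick sanity check (a minimal sketch, not part of the library; the tensor shapes and seed are arbitrary), the two formulations produce the same loss on the same logits and targets:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
y = torch.randn(8, 5)          # hypothetical batch of 8 logits over 5 classes
t = torch.randint(0, 5, (8,))  # hypothetical integer class targets

# Current two-step version (dim passed explicitly to silence the warning)
loss_two_step = F.nll_loss(F.log_softmax(y, dim=1), t, weight=None, reduction="mean")

# Proposed single-call version
loss_fused = F.cross_entropy(y, t, weight=None, reduction="mean")

print(torch.allclose(loss_two_step, loss_fused))  # → True
```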

ZaydH avatar Jan 23 '21 14:01 ZaydH

I thought cross_entropy just combines log_softmax and nll_loss.

pomonam avatar Jan 23 '21 21:01 pomonam