pytorch_influence_functions
Improve Numerical Stability of the calc_loss Function
Summary: We can improve the numerical stability and accuracy of the calc_loss method.
The current implementation uses the following:
```python
def calc_loss(y, t):
    y = torch.nn.functional.log_softmax(y)
    loss = torch.nn.functional.nll_loss(
        y, t, weight=None, reduction='mean')
    return loss
```
PyTorch provides a single function, cross_entropy, that is numerically more stable. It would also simplify the above code to:
```python
def calc_loss(y, t):
    loss = torch.nn.functional.cross_entropy(
        y, t, weight=None, reduction="mean")
    return loss
```
I thought cross_entropy just combines log_softmax and nll_loss.
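That's right. A quick sketch (with made-up example logits and targets) showing that the fused cross_entropy call produces the same loss as the two-step log_softmax + nll_loss version:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)              # batch of 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])    # ground-truth class indices

# Two-step version: explicit log_softmax followed by nll_loss
loss_two_step = F.nll_loss(
    F.log_softmax(logits, dim=1), targets, reduction="mean")

# Fused version: a single cross_entropy call on the raw logits
loss_fused = F.cross_entropy(logits, targets, reduction="mean")

print(torch.allclose(loss_two_step, loss_fused))  # → True
```

The fused call also avoids the deprecation warning from calling log_softmax without an explicit dim argument.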