pytorch-multi-class-focal-loss
Implementation is incorrect
ce_loss = F.cross_entropy(input, target,reduction=self.reduction,weight=self.weight)
pt = torch.exp(-ce_loss)
focal_loss = ((1 - pt) ** self.gamma * ce_loss).mean()
While ce_loss is correctly -weight[y_t] * log(p[y_t]), pt is not p[y_t] as you would expect. Since pt = exp(-ce_loss) = exp(weight[y_t] * log(p[y_t])), it is p[y_t]^weight[y_t], which equals p[y_t] only when the weight is 1.
Also, because the reduction is performed at the CE step, pt is a scalar (for reduction 'mean' or 'sum') rather than a tensor of the probabilities at individual spatial positions.
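One possible way to address both points is sketched below; it is not the repository's code. It assumes the input holds logits of shape (N, C, ...) and the target holds class indices of shape (N, ...). The class name FocalLoss and its constructor arguments are chosen here for illustration. The sketch computes the unweighted per-element cross entropy so that exp(-ce) really is p[y_t], applies the class weight afterwards, and reduces only at the end. Note that 'mean' here is a plain mean, not the weight-normalized mean that F.cross_entropy uses, and ignore_index is not handled.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Illustrative multi-class focal loss (hypothetical names, not the repo's API)."""

    def __init__(self, gamma=2.0, weight=None, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.weight = weight          # optional tensor of shape (C,), same device as input
        self.reduction = reduction    # "mean", "sum", or "none"

    def forward(self, input, target):
        # Log-probabilities over the class dimension.
        log_p = F.log_softmax(input, dim=1)
        # Unweighted, unreduced cross entropy: ce = -log(p[y_t]) per element.
        ce = F.nll_loss(log_p, target, reduction="none")
        # Because ce is unweighted, exp(-ce) is exactly p[y_t].
        pt = torch.exp(-ce)
        focal = (1 - pt) ** self.gamma * ce
        # Apply the per-class weight after the modulating factor.
        if self.weight is not None:
            focal = focal * self.weight[target]
        # Reduce only at the very end.
        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal
```

With gamma set to 0 and reduction "none", this should reduce to the weighted per-element cross entropy, which is a convenient sanity check.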
Hi @ctensmeyer, can you share the correct implementation code?