Class-balanced-loss-pytorch
Why this modulator?
Hi, I'm interested in your work! I have a question: why does your focal loss implementation use "modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + torch.exp(-1.0 * logits)))"?
It looks like a direct transcription of the formula, but "labels" should not appear in the code. I think it should be modulator = torch.exp(-gamma * logits - gamma * torch.log(1 + torch.exp(-1.0 * logits)))
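A quick numerical check makes the difference between the two expressions concrete. This is a minimal sketch in plain Python (scalars instead of tensors, so it runs without PyTorch). The function names are hypothetical, and `focal_factor` assumes the standard focal-loss factor (1 - p_t) ** gamma, with p_t = p for a positive label and 1 - p for a negative one.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def repo_modulator(y, z, gamma):
    # Expression quoted in the question: keeps the label y in the exponent.
    return math.exp(-gamma * y * z - gamma * math.log(1 + math.exp(-z)))


def proposed_modulator(y, z, gamma):
    # Label-free variant proposed above; y is deliberately unused.
    return math.exp(-gamma * z - gamma * math.log(1 + math.exp(-z)))


def focal_factor(y, z, gamma):
    # Textbook focal-loss factor (1 - p_t) ** gamma computed from probabilities.
    p = sigmoid(z)
    p_t = p if y == 1 else 1.0 - p
    return (1.0 - p_t) ** gamma


gamma = 2.0
for y in (0, 1):
    for z in (-3.0, 0.5, 2.0):
        print(
            f"y={y} z={z:+.1f} "
            f"repo={repo_modulator(y, z, gamma):.6f} "
            f"proposed={proposed_modulator(y, z, gamma):.6f} "
            f"focal={focal_factor(y, z, gamma):.6f}"
        )
```

Running it should show the expression with labels agreeing with `focal_factor` for both label values, while the label-free version agrees only when the label is 1, because it always evaluates to (1 - p) ** gamma.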
It is the focal loss formula, and "labels" has to stay in it.
Notice that y' = 1 / (1 + torch.exp(-1.0 * logits)) is the sigmoid of the logits, and that both logits and labels are matrices, so every operation is element-wise.
Where labels is 0, modulator = (1 / (1 + torch.exp(-1.0 * logits))) ** gamma = y' ** gamma.
Where labels is 1, modulator = (torch.exp(-1.0 * logits) / (1 + torch.exp(-1.0 * logits))) ** gamma = (1 / (1 + torch.exp(logits))) ** gamma = (1 - y') ** gamma.
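The two cases can also be folded into a single expression. Writing $z$ for logits, $y \in \{0, 1\}$ for labels and $y'$ for the sigmoid above, and borrowing the usual focal-loss notation $p_t$ (not used elsewhere in this thread), the algebra is:

```latex
\begin{aligned}
\log y' &= -\log\left(1 + e^{-z}\right), \qquad
\log(1 - y') = -z - \log\left(1 + e^{-z}\right), \\
-\gamma y z - \gamma \log\left(1 + e^{-z}\right)
  &= \gamma \left[\, y \log(1 - y') + (1 - y) \log y' \,\right], \\
\text{modulator} &= (1 - y')^{\gamma y} \, (y')^{\gamma (1 - y)} = (1 - p_t)^{\gamma},
\qquad p_t = \begin{cases} y' & y = 1 \\ 1 - y' & y = 0 \end{cases}
\end{aligned}
```

The label acts as an exponent that selects which of the two factors survives, which is exactly why it cannot be dropped from the code.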
This matches the focal loss formula: the modulating factor is y' ** gamma for negatives and (1 - y') ** gamma for positives.
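One practical note on writing the modulator this way: it works directly on logits, but the literal torch.log(1 + torch.exp(-1.0 * logits)) can overflow for strongly negative logits (in float32, exp(-logits) becomes inf once logits drop below roughly -88), which drives the modulator for a positive label to 0 instead of about 1. Below is a minimal sketch of a more robust variant, again in plain Python with hypothetical function names, that replaces that term with a numerically stable softplus. In PyTorch the analogous substitution would be torch.nn.functional.softplus(-logits).

```python
import math


def softplus(x):
    # log(1 + exp(x)) without overflow: max(x, 0) + log1p(exp(-|x|)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def stable_modulator(y, z, gamma):
    # Same quantity as exp(-gamma*y*z - gamma*log(1 + exp(-z))), with
    # log(1 + exp(-z)) rewritten as softplus(-z) so exp never overflows.
    return math.exp(-gamma * (y * z + softplus(-z)))


gamma = 2.0
for y in (0, 1):
    for z in (-1000.0, -3.0, 0.0, 3.0, 1000.0):
        print(f"y={y} z={z:g} modulator={stable_modulator(y, z, gamma):.6f}")
```

For extreme logits this returns 1.0 or 0.0 as the focal factor demands, and at z = 0 it gives 0.5 ** gamma for either label.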