
About sigmoid_focal_loss

Open BinhuiXie opened this issue 2 years ago • 5 comments

Hello, guys. Well done!

I have a quick question about sigmoid_focal_loss as follows: https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L379

Why is the third dimension of target_classes_onehot one larger than that of src_logits? Does the extra slot represent the "no object" class?

Thanks in advance.

BinhuiXie avatar Jul 25 '22 03:07 BinhuiXie

Yes, the extra slot does represent the "no object" class.
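
For reference, the target construction around the linked line roughly follows the DETR-style pattern below (a paraphrased, self-contained sketch; the sizes are made-up examples):

import torch

# Made-up sizes for illustration: 2 images, 3 queries, 4 classes.
batch_size, num_queries, num_classes = 2, 3, 4
src_logits = torch.randn(batch_size, num_queries, num_classes)

# One label per query; the value num_classes (here 4) means "no object".
target_classes = torch.full((batch_size, num_queries), num_classes, dtype=torch.int64)
target_classes[0, 1] = 2  # pretend query 1 of image 0 was matched to class 2

# Build the one-hot with one extra slot so scatter_ has somewhere to put
# the "no object" label ...
target_classes_onehot = torch.zeros(batch_size, num_queries, num_classes + 1)
target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)

# ... then drop that slot, so unmatched queries get an all-zero target row
# for the sigmoid focal loss.
target_classes_onehot = target_classes_onehot[:, :, :-1]
print(target_classes_onehot.shape, src_logits.shape)  # both torch.Size([2, 3, 4])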

HaoZhang534 avatar Jul 25 '22 06:07 HaoZhang534

thanks!

BinhuiXie avatar Jul 25 '22 06:07 BinhuiXie

Sorry, another question.

https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L384

Why is sigmoid_focal_loss (binary cross-entropy with logits plus a focal modulating term) usually used in object detection? What are the advantages? Could we use standard cross-entropy with softmax instead?

BinhuiXie avatar Jul 25 '22 07:07 BinhuiXie

In my understanding, sigmoid scores each class independently (multi-label style), which suits this setting better. When the model is not sure which of two classes an object belongs to, it can give both a high score, so one of them is still correct.
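
A quick toy illustration of that difference (the logits are made up, not from DINO):

import torch
import torch.nn.functional as F

# One query over 4 classes; the model is torn between class 0 and class 1.
logits = torch.tensor([3.0, 3.0, -2.0, -2.0])

# Softmax makes the classes compete, so the two candidates split the probability mass.
print(F.softmax(logits, dim=-1))  # ~[0.50, 0.50, 0.003, 0.003]

# Sigmoid scores each class on its own, so both candidates can stay high and
# whichever one is actually correct still gets a large probability.
print(torch.sigmoid(logits))      # ~[0.95, 0.95, 0.12, 0.12]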

HaoZhang534 avatar Jul 25 '22 19:07 HaoZhang534

That makes sense.

In fact, I tried a softmax_focal_loss modeled on sigmoid_focal_loss, as follows:

import torch
import torch.nn.functional as F


def softmax_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Args:
        inputs: A float tensor of shape [batch_size, num_queries, num_classes].
                The classification logits for each query.
        targets: An integer tensor of shape [batch_size, num_queries] holding
                 the class index for each query.
        num_boxes: Number of ground-truth boxes, used to normalize the loss.
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    _EPSILON = 1e-4

    prob = F.softmax(inputs, dim=2)
    # Probability of the target class for each query: [batch_size, num_queries, 1]
    prob = prob.gather(-1, targets.unsqueeze(-1))

    logpt = torch.log(torch.clamp(prob, _EPSILON, 1 - _EPSILON))
    focal_modulation = (1 - prob) ** gamma

    loss = -alpha * focal_modulation * logpt

    # Mean over queries, sum over the batch, normalized by the number of boxes
    # (same reduction as sigmoid_focal_loss in dino.py).
    return loss.mean(1).sum() / num_boxes
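
For completeness, a quick shape check with dummy tensors (the sizes here are hypothetical, just to show how I call it):

import torch

# Hypothetical sizes: 2 images, 900 queries, 91 classes.
dummy_logits = torch.randn(2, 900, 91)
dummy_targets = torch.randint(0, 91, (2, 900))
print(softmax_focal_loss(dummy_logits, dummy_targets, num_boxes=10))  # a scalar loss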

And then I modified the following lines: https://github.com/IDEACVR/DINO/blob/67bbcd97ef30a48cf343b7b0f3ad9ea0795b6fcd/models/dino/dino.py#L383-L385

    # Drop the "no object" slot as before, then recover hard labels with argmax;
    # note that unmatched queries have all-zero rows, so argmax maps them to class 0.
    target_classes_onehot = target_classes_onehot[:, :, :-1]
    loss_ce = softmax_focal_loss(src_logits, target_classes_onehot.argmax(-1), num_boxes,
                                 alpha=self.focal_alpha, gamma=2) * src_logits.shape[1]
    losses = {'loss_ce': loss_ce}

However, the performance drops considerably.

Could you offer some advice? Thanks!

BinhuiXie avatar Jul 26 '22 05:07 BinhuiXie