BERT-NER-Pytorch

Is the DiceLoss formula written correctly, and how should it be understood?

Open LLLLLLoki opened this issue 5 years ago • 1 comments

def forward(self,input, target):
    '''
    input: [N, C]
    target: [N, ]
    '''
    prob = torch.softmax(input, dim=1)
    prob = torch.gather(prob, dim=1, index=target.unsqueeze(1))
    dsc_i = 1 - ((1 - prob) * prob) / ((1 - prob) * prob + 1)
    dice_loss = dsc_i.mean()
    return dice_loss

In the paper, it is DSC(Xi) = (2(1-p)p*y + r) / ((1-p)p + y + r)
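To make the comparison concrete, here is a minimal sketch that evaluates both expressions for a few probabilities. It assumes y = 1 (since `torch.gather` selects the probability of the target class) and r = 1 (since the code adds a constant 1 in the denominator); both values are assumptions made for illustration, not something the repository states. The function names are hypothetical.

```python
def loss_from_issue_code(p):
    """Loss term exactly as written in the posted forward(): 1 - q / (q + 1)."""
    q = (1 - p) * p
    return 1 - q / (q + 1)


def dsc_from_quoted_formula(p, y=1.0, r=1.0):
    """DSC as transcribed above: (2(1-p)p*y + r) / ((1-p)p + y + r).

    y=1 and r=1 are assumed defaults for this comparison only.
    """
    q = (1 - p) * p
    return (2 * q * y + r) / (q + y + r)


# Compare the code's loss with 1 - DSC from the transcribed formula.
for p in (0.1, 0.5, 0.9):
    code_loss = loss_from_issue_code(p)
    formula_loss = 1 - dsc_from_quoted_formula(p)
    print(f"p={p:.1f}  code={code_loss:.4f}  1-DSC={formula_loss:.4f}")
```

Under these assumptions the two columns do not match: the code drops the factor of 2 and the `+ r` in the numerator, as well as the `+ y` term in the denominator. So the posted code is not a term-by-term implementation of the formula as transcribed here.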

LLLLLLoki, May 20 '20 07:05

I don't see diceloss being used anywhere in the code, only focalloss and label smoothing.

lvjiujin, Aug 31 '21 05:08