BERT-NER-Pytorch
Is this DiceLoss formula correct? How should it be understood?
def forward(self, input, target):
    '''
    input: [N, C]  (logits)
    target: [N, ]  (gold class indices)
    '''
    prob = torch.softmax(input, dim=1)
    # keep only the probability assigned to the gold class: shape [N, 1]
    prob = torch.gather(prob, dim=1, index=target.unsqueeze(1))
    dsc_i = 1 - ((1 - prob) * prob) / ((1 - prob) * prob + 1)
    # average over the N examples
    dice_loss = dsc_i.mean()
    return dice_loss
In the paper it is: DSC(x_i) = (2(1-p)p·y + γ) / ((1-p)p + y + γ)
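To compare the two concretely, here is a small plain-Python sketch (no PyTorch) that evaluates both expressions for a single example whose gold-class probability is p. It assumes y = 1 (because `gather` selects the gold class), treats γ as a free parameter, and assumes the per-example loss is 1 - DSC(x_i); the function names are hypothetical. Writing a = (1 - p)p, the snippet simplifies to 1 / (1 + a), whereas 1 - DSC with y = 1 works out to (1 - a) / (a + 1 + γ), so the two do not coincide for any single γ.

```python
# Hypothetical helpers for comparing the quoted snippet with the paper's DSC formula.
# Assumptions: y = 1 for the gold class, loss = 1 - DSC(x_i), gamma is a free parameter.


def snippet_loss(p: float) -> float:
    """Per-example value computed by the quoted forward(): 1 - a / (a + 1), a = (1 - p) * p."""
    a = (1 - p) * p
    return 1 - a / (a + 1)


def paper_dsc(p: float, y: float = 1.0, gamma: float = 1.0) -> float:
    """DSC(x_i) = (2(1-p)p*y + gamma) / ((1-p)p + y + gamma), as quoted above."""
    a = (1 - p) * p
    return (2 * a * y + gamma) / (a + y + gamma)


def paper_loss(p: float, y: float = 1.0, gamma: float = 1.0) -> float:
    """Per-example loss written as 1 - DSC(x_i) (assumed form)."""
    return 1 - paper_dsc(p, y, gamma)


if __name__ == "__main__":
    for p in (0.1, 0.5, 0.9, 1.0):
        print(f"p={p:.1f}  snippet={snippet_loss(p):.4f}  1-DSC(gamma=1)={paper_loss(p):.4f}")
```

Both expressions depend on p only through (1 - p)p, so each gives the same value at p and 1 - p (for example at p = 0.1 and p = 0.9).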
I don't see DiceLoss actually used anywhere in the code; only focal loss and label smoothing are used.