
How to add focal loss to CTC?

Open Xujianzhong opened this issue 5 years ago • 8 comments

I want to introduce focal loss into the CTC loss in my project, but I don't know where to add it. I hope someone can help me. Thanks!

Xujianzhong avatar Jun 03 '19 12:06 Xujianzhong

Me neither.

xingchensong avatar Jun 04 '19 02:06 xingchensong

@stephen-song (for convenience, please allow me to use Chinese) The CTC loss currently uses the cross-entropy form L = -log(p). I would like to change the loss computation to a form like L = (1-p)·log(p). How should I modify the code? Many thanks!

Xujianzhong avatar Jun 05 '19 09:06 Xujianzhong
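For reference, the standard focal loss (Lin et al., 2017) is written with a minus sign, L = -(1-p)^γ·log(p), where γ is the focusing exponent. Since the CTC loss already equals -log(p) of the label sequence, p can be recovered from it and the focal variant expressed directly in terms of the CTC loss:

\[
L_{\text{CTC}} = -\log p, \qquad p = e^{-L_{\text{CTC}}}, \qquad
L_{\text{focal}} = (1-p)^{\gamma}\, L_{\text{CTC}} = -(1-p)^{\gamma}\log p
\]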

Write a Lambda function: first get the CTC loss value, then raise e to the power of the negative CTC loss to get p, and from there you can build the new loss you want. I've implemented this in TF/Keras; haven't tried PyTorch yet.

knowledge3 avatar Sep 28 '19 01:09 knowledge3
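A minimal TensorFlow sketch of that recipe (illustrative only, not the commenter's actual code; the function name focal_ctc_loss and the defaults α = 0.25, γ = 2 are assumptions, with α being the focal-loss balancing factor asked about below):

import tensorflow as tf

def focal_ctc_loss(labels, logits, label_length, logit_length,
                   alpha=0.25, gamma=2.0):
    # per-sample CTC loss: -log p(label | input)
    ctc = tf.nn.ctc_loss(labels=labels, logits=logits,
                         label_length=label_length,
                         logit_length=logit_length,
                         logits_time_major=True, blank_index=0)
    p = tf.exp(-ctc)  # recover the per-sample sequence probability
    return tf.reduce_mean(alpha * tf.pow(1.0 - p, gamma) * ctc)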

> Write a Lambda function: first get the CTC loss value, then raise e to the power of the negative CTC loss to get p, and from there you can build the new loss you want. I've implemented this in TF/Keras; haven't tried PyTorch yet.

Hi, could you please share your code? The TensorFlow version. I still have some questions about α and a few other parameters, and would appreciate an explanation.

Tangzixia avatar Nov 29 '19 02:11 Tangzixia

> Write a Lambda function: first get the CTC loss value, then raise e to the power of the negative CTC loss to get p, and from there you can build the new loss you want. I've implemented this in TF/Keras; haven't tried PyTorch yet.

Does adding focal loss on top of CTC loss actually help?

ChChwang avatar May 19 '21 12:05 ChChwang

Rough code based on the discussion above (not verified; I'll try running it soon):

import torch
import torch.nn.functional as F

criterion = torch.nn.CTCLoss()  # default reduction='mean'

# fm shape: [T, B, C]  (time steps, batch, classes)
output = F.log_softmax(fm, dim=2)
ctc_loss = criterion(output, target, pred_size, target_length)

# CTC loss is -log(p), so the sequence probability is p = exp(-ctc_loss)
p = torch.exp(-ctc_loss)
focal_loss = -(1 - p) * torch.log(p)  # equals (1 - p) * ctc_loss

blueardour avatar Sep 27 '21 08:09 blueardour

[update] No benefit in my training with the above code.

blueardour avatar Sep 28 '21 06:09 blueardour
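Note that with PyTorch's default reduction='mean', the focal weight above is computed from a single batch-averaged loss. A per-sample variant (a sketch reusing the assumed variables fm, target, pred_size, and target_length from the snippet above) would use reduction='none', as the Paddle code below also does:

import torch
import torch.nn.functional as F

criterion = torch.nn.CTCLoss(reduction='none')  # one loss value per sample
output = F.log_softmax(fm, dim=2)               # fm: [T, B, C]
ctc_loss = criterion(output, target, pred_size, target_length)  # shape [B]

p = torch.exp(-ctc_loss)                         # per-sample sequence probability
focal_loss = ((1.0 - p) ** 2 * ctc_loss).mean()  # gamma = 2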

Here is the Paddle version:

import paddle
from paddle import nn

class CTCLoss(nn.Layer):

    def __init__(self, use_focal_loss=False, **kwargs):
        super(CTCLoss, self).__init__()
        # reduction='none' keeps one loss per sample, so the focal
        # weight can be applied before averaging
        self.loss_func = nn.CTCLoss(blank=0, reduction='none')
        self.use_focal_loss = use_focal_loss

    def forward(self, predicts, batch):
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]
        # [B, T, C] -> [T, B, C] (time-major, as nn.CTCLoss expects)
        predicts = predicts.transpose((1, 0, 2))
        N, B, _ = predicts.shape
        preds_lengths = paddle.to_tensor(
            [N] * B, dtype='int64', place=paddle.CPUPlace())
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype('int64')
        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
        if self.use_focal_loss:
            # focal weighting with gamma = 2: weight = (1 - p)^2, p = exp(-loss)
            weight = paddle.exp(-loss)
            weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
            weight = paddle.square(weight)
            loss = paddle.multiply(loss, weight)
        loss = loss.mean()
        return {'loss': loss}

huangxin168 avatar Nov 26 '22 17:11 huangxin168
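A hypothetical usage sketch for the class above (all shapes and the batch layout are assumptions inferred from how forward indexes batch):

import paddle

loss_fn = CTCLoss(use_focal_loss=True)
predicts = paddle.randn([8, 40, 37])           # [B, T, C]: batch 8, 40 time steps, 37 classes
labels = paddle.randint(1, 37, shape=[8, 10])  # padded label ids (0 is the blank)
label_lens = paddle.to_tensor([10] * 8, dtype='int64')
batch = [None, labels, label_lens]             # forward reads batch[1] and batch[2]
print(loss_fn(predicts, batch)['loss'])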