mmsegmentation-distiller icon indicating copy to clipboard operation
mmsegmentation-distiller copied to clipboard

cwd loss 和论文中的公式不一致

Open Crazod opened this issue 3 years ago • 3 comments

hi,请问作者 在论文里你用到了KL去学习Teacher的输出。 Equation (4)

2021-11-26 12-03-51屏幕截图 在你的代码里 loss = torch.sum(-softmax_pred_T * logsoftmax(preds_S.view(-1, W * H) / self.tau)) * ( self.tau**2) 你使用的这个公式去计算KD loss。 但是标准的KL Loss应该是这样 kl_loss = torch.sum(softmax_pred_T * (logsoftmax(preds_T.view(-1, W * H) / self.tau) - logsoftmax(preds_S.view(-1, W * H) / self.tau))) * ( self.tau**2) 是论文中哪里做了省略或者参数的近似么?

Crazod avatar Nov 26 '21 04:11 Crazod

如果把 kl 的公式展开,和 student 无关的项省掉后(teacher 部分不影响梯度更新,只影响数值),就是代码里的形式

pppppM avatar Nov 26 '21 04:11 pppppM

如果把 kl 的公式展开,和 student 无关的项省掉后(teacher 部分不影响梯度更新,只影响数值),就是代码里的形式

明白了,多谢作者。

Crazod avatar Nov 26 '21 04:11 Crazod

谢谢,我也有相同的疑惑,看到这里就明白了

HaoKun-Li avatar Jun 27 '22 09:06 HaoKun-Li