mmsegmentation-distiller
mmsegmentation-distiller copied to clipboard
cwd loss 和论文中的公式不一致
hi,请问作者 在论文里你用到了KL去学习Teacher的输出。 Equation (4)
在你的代码里
loss = torch.sum(-softmax_pred_T * logsoftmax(preds_S.view(-1, W * H) / self.tau)) * ( self.tau**2)
你使用的这个公式去计算KD loss。
但是标准的KL Loss应该是这样
kl_loss = torch.sum(softmax_pred_T * (logsoftmax(preds_T.view(-1, W * H) / self.tau) - logsoftmax(preds_S.view(-1, W * H) / self.tau))) * ( self.tau**2)
是论文中哪里做了省略或者参数的近似么?
如果把 kl 的公式展开,和 student 无关的项省掉后(teacher 部分不影响梯度更新,只影响数值),就是代码里的形式
如果把 kl 的公式展开,和 student 无关的项省掉后(teacher 部分不影响梯度更新,只影响数值),就是代码里的形式
明白了,多谢作者。
谢谢,我也有相同的疑惑,看到这里就明白了