
On the inconsistency between F.log_softmax and F.softmax in distillation_loss

Open · zhujiesuper opened this issue Sep 08 '20 · 4 comments

import torch.nn as nn
import torch.nn.functional as F

def distillation(y, labels, teacher_scores, temp, alpha):
    print("teacher_scores", teacher_scores)
    # soft part: KL between student and teacher outputs, both softened by temp
    # hard part: ordinary cross-entropy against the ground-truth labels
    return nn.KLDivLoss()(F.log_softmax(y / temp, dim=1),
                          F.softmax(teacher_scores / temp, dim=1)) * (temp * temp * 2.0 * alpha) \
        + F.cross_entropy(y, labels) * (1. - alpha)

In this code, why does F.log_softmax(y / temp, dim=1) use the log, while F.softmax(teacher_scores / temp, dim=1) does not?

zhujiesuper · Sep 08 '20 11:09

Here is a manual implementation of the KL divergence:

import torch
import torch.nn.functional as F

# input_logit:  [batch, class_num]
# target_logit: [batch, class_num]
def kl_loss(input_logit, target_logit):
    # per sample: KL(target || input) = sum_c p_t(c) * (log p_t(c) - log p_in(c))
    prob = F.softmax(target_logit, dim=-1)
    kl = torch.sum(prob * (F.log_softmax(target_logit, dim=-1)
                           - F.log_softmax(input_logit, dim=-1)), 1)
    return torch.mean(kl)

Alternatively:

import torch
import torch.nn.functional as F

# input_log_softmax: [batch, class_num]
# target_softmax:    [batch, class_num]
def kl_loss(input_log_softmax, target_softmax):
    prob = target_softmax
    kl = torch.sum(prob * (input_log_softmax - torch.log(target_softmax)), 1)
    return torch.mean(kl)

Clearly, PyTorch adopted the second scheme, which avoids recomputing the softmax.
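
As a quick sanity check (a sketch with made-up random logits; note that reduction='batchmean' is the setting that matches the mean-over-batch of the per-sample sums), the built-in loss and the manual version agree:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
input_logit = torch.randn(4, 10)   # hypothetical student logits
target_logit = torch.randn(4, 10)  # hypothetical teacher logits

# built-in: expects (log-probabilities, probabilities)
builtin = nn.KLDivLoss(reduction='batchmean')(
    F.log_softmax(input_logit, dim=-1),
    F.softmax(target_logit, dim=-1))

# manual version from the first snippet above
prob = F.softmax(target_logit, dim=-1)
manual = torch.mean(torch.sum(
    prob * (F.log_softmax(target_logit, dim=-1)
            - F.log_softmax(input_logit, dim=-1)), 1))

print(torch.allclose(builtin, manual))  # True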

If the arguments passed in are (input_log_softmax, target_log_softmax), then target_log_softmax needs to be converted to target_softmax to recover the probabilities.

If the arguments passed in are (input_softmax, target_softmax), then input_softmax needs to be converted to input_log_softmax.
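
A minimal sketch of both conversions (the logits below are made up for illustration; the log_target flag is available in PyTorch >= 1.6):

import torch
import torch.nn as nn
import torch.nn.functional as F

logits_s = torch.randn(4, 10)  # hypothetical student logits
logits_t = torch.randn(4, 10)  # hypothetical teacher logits
input_log_softmax = F.log_softmax(logits_s, dim=-1)
target_log_softmax = F.log_softmax(logits_t, dim=-1)
input_softmax = F.softmax(logits_s, dim=-1)
target_softmax = F.softmax(logits_t, dim=-1)

kl = nn.KLDivLoss(reduction='batchmean')

# case 1: both in log space -> exponentiate the target first
loss1 = kl(input_log_softmax, target_log_softmax.exp())
# equivalently, PyTorch >= 1.6 accepts a log-space target directly
loss1b = nn.KLDivLoss(reduction='batchmean', log_target=True)(
    input_log_softmax, target_log_softmax)

# case 2: both are probabilities -> take the log of the input
loss2 = kl(input_softmax.log(), target_softmax)

print(torch.allclose(loss1, loss1b), torch.allclose(loss1, loss2))  # True True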

For the details, see https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html

mepeichun · Sep 11 '20 05:09

Thanks, that's clear and easy to understand. Great!


zhujiesuper · Sep 11 '20 06:09

kl = torch.sum(prob * (input_log_softmax - torch.log(target_softmax)), 1)

After reading an article (loss函数之KLDivLoss), I feel this line of code should be changed to the following; I'm not sure whether my understanding is correct:

kl = torch.sum(prob * (torch.log(target_softmax) - input_log_softmax), 1)
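
A quick numeric check (a sketch with random logits) that the corrected line agrees with F.kl_div and is non-negative, as a KL divergence should be:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
input_log_softmax = F.log_softmax(torch.randn(4, 10), dim=-1)
target_softmax = F.softmax(torch.randn(4, 10), dim=-1)

prob = target_softmax
kl = torch.mean(torch.sum(
    prob * (torch.log(target_softmax) - input_log_softmax), 1))

ref = F.kl_div(input_log_softmax, target_softmax, reduction='batchmean')
print(torch.allclose(kl, ref), kl.item() >= 0)  # True True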

AlphaGogoo · Mar 04 '22 07:03

Correct.

JevenM · Sep 22 '22 04:09