ConSERT
ConSERT copied to clipboard
NT-Xent损失函数
感谢您非常出色的工作。有个小问题想问一下,看过代码后发现loss损失函数的实现和nt-xent损失函数的公式不太一致?

您好,我们对比损失的实现代码在这里,实现上主要参考的是SimCLR的官方实现,具体含义如下:
if hidden_norm: # 对句向量做L2归一化,归一化后向量点积的结果即为原向量的cosine相似度
hidden1 = torch.nn.functional.normalize(hidden1, p=2, dim=-1)
hidden2 = torch.nn.functional.normalize(hidden2, p=2, dim=-1)
hidden1_large = hidden1
hidden2_large = hidden2
# 下面的batch_size用N来表示
# labels 为 [0, 1, ..., N - 1]
labels = torch.arange(0, batch_size).to(device=hidden1.device)
# masks 为 N * N 的一个单位矩阵,对角元素为1.0,其余元素为0.0
# 这个masks对应上面公式中的 [k != i],用来将自身mask掉
masks = torch.nn.functional.one_hot(torch.arange(0, batch_size), num_classes=batch_size).to(device=hidden1.device, dtype=torch.float)
# 下面的logits_aa, logits_ab, logits_bb, logits_ba用来生成一个2N * 2N相似度矩阵,只不过分成了四个N*N的子矩阵。
# 其中,logits_aa和logits_bb,因为其对角线包含自身和自身的乘积(对应上面公式中k==i的那一项,其结果将永远为1),所以
# 将这部分数值替换为负无穷,这样在计算exp之后,其结果近似为0,也就可以近似认为这一项不存在
logits_aa = torch.matmul(hidden1, hidden1_large.transpose(0, 1)) / temperature # shape (bsz, bsz)
logits_aa = logits_aa - masks * LARGE_NUM
logits_bb = torch.matmul(hidden2, hidden2_large.transpose(0, 1)) / temperature # shape (bsz, bsz)
logits_bb = logits_bb - masks * LARGE_NUM
logits_ab = torch.matmul(hidden1, hidden2_large.transpose(0, 1)) / temperature # shape (bsz, bsz)
logits_ba = torch.matmul(hidden2, hidden1_large.transpose(0, 1)) / temperature # shape (bsz, bsz)
# torch.cat([logits_ab, logits_aa]中,每一行包含了公式中分母的2N - 1项(因为mask掉了自身那一项),而分子那一项
# 则通过labels体现,这里可以通过控制logits子矩阵的拼接,使得相同样本、不同view生成的句向量的相似度计算放在前
# 面,这样其 labels 正好为 [0, 1, ..., N - 1]
loss_a = torch.nn.functional.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), labels) # 对view1的N个句向量计算交叉熵损失
loss_b = torch.nn.functional.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), labels) # 对view2的N个句向量计算交叉熵损失,原理与上面一致,只不过调换了子矩阵的顺序
loss = loss_a + loss_b # 2N个样本的交叉熵损失之和
return loss
综合来看应该是和论文中的公式一致的,只不过论文中的公式其实只是对应于一个batch内单个样本的分类,还需要在外面加一个2N的求和。
太详细了,学习到了,非常感谢。
loss = loss_a + loss_b # 2N个样本的交叉熵损失之和
所以最后为什么没除个2N呢再?不除的话batch_size会影响loss的
loss = loss_a + loss_b # 2N个样本的交叉熵损失之和
所以最后为什么没除个2N呢再?不除的话batch_size会影响loss的
loss_a和loss_b应该都是平均之后的,torch.nn.functional.cross_entropy默认reduction='mean',这里确实应该再除2。