ConSERT icon indicating copy to clipboard operation
ConSERT copied to clipboard

NT-Xent损失函数

Open wsh2836741 opened this issue 4 years ago • 4 comments

感谢您非常出色的工作。有个小问题想问一下,看过代码后发现loss损失函数的实现和nt-xent损失函数的公式不太一致? 44C050B4-0270-4956-BE10-665840F167C3

wsh2836741 avatar Jul 27 '21 12:07 wsh2836741

您好,我们对比损失的实现代码在这里,实现上主要参考的是SimCLR的官方实现,具体含义如下:

            if hidden_norm:  # 对句向量做L2归一化,归一化后向量点积的结果即为原向量的cosine相似度
                hidden1 = torch.nn.functional.normalize(hidden1, p=2, dim=-1)
                hidden2 = torch.nn.functional.normalize(hidden2, p=2, dim=-1)

            hidden1_large = hidden1
            hidden2_large = hidden2

            # 下面的batch_size用N来表示
            # labels 为 [0, 1, ..., N - 1]
            labels = torch.arange(0, batch_size).to(device=hidden1.device)

            # masks 为 N * N 的一个单位矩阵,对角元素为1.0,其余元素为0.0
            # 这个masks对应上面公式中的 [k != i],用来将自身mask掉
            masks = torch.nn.functional.one_hot(torch.arange(0, batch_size), num_classes=batch_size).to(device=hidden1.device, dtype=torch.float)

            # 下面的logits_aa, logits_ab, logits_bb, logits_ba用来生成一个2N * 2N相似度矩阵,只不过分成了四个N*N的子矩阵。
            # 其中,logits_aa和logits_bb,因为其对角线包含自身和自身的乘积(对应上面公式中k==i的那一项,其结果将永远为1),所以
            # 将这部分数值替换为负无穷,这样在计算exp之后,其结果近似为0,也就可以近似认为这一项不存在
            logits_aa = torch.matmul(hidden1, hidden1_large.transpose(0, 1)) / temperature  # shape (bsz, bsz)
            logits_aa = logits_aa - masks * LARGE_NUM
            logits_bb = torch.matmul(hidden2, hidden2_large.transpose(0, 1)) / temperature  # shape (bsz, bsz)
            logits_bb = logits_bb - masks * LARGE_NUM
            logits_ab = torch.matmul(hidden1, hidden2_large.transpose(0, 1)) / temperature  # shape (bsz, bsz)
            logits_ba = torch.matmul(hidden2, hidden1_large.transpose(0, 1)) / temperature  # shape (bsz, bsz)

            # torch.cat([logits_ab, logits_aa]中,每一行包含了公式中分母的2N - 1项(因为mask掉了自身那一项),而分子那一项
            # 则通过labels体现,这里可以通过控制logits子矩阵的拼接,使得相同样本、不同view生成的句向量的相似度计算放在前
            # 面,这样其 labels 正好为 [0, 1, ..., N - 1]
            loss_a = torch.nn.functional.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), labels)  # 对view1的N个句向量计算交叉熵损失
            loss_b = torch.nn.functional.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), labels)   # 对view2的N个句向量计算交叉熵损失,原理与上面一致,只不过调换了子矩阵的顺序

            loss = loss_a + loss_b  # 2N个样本的交叉熵损失之和
            return loss

综合来看应该是和论文中的公式一致的,只不过论文中的公式其实只是对应于一个batch内单个样本的分类,还需要在外面加一个2N的求和。

yym6472 avatar Jul 27 '21 13:07 yym6472

太详细了,学习到了,非常感谢。

wsh2836741 avatar Jul 28 '21 02:07 wsh2836741

loss = loss_a + loss_b # 2N个样本的交叉熵损失之和

所以最后为什么没除个2N呢再?不除的话batch_size会影响loss的

qishibo avatar Dec 10 '21 08:12 qishibo

loss = loss_a + loss_b # 2N个样本的交叉熵损失之和

所以最后为什么没除个2N呢再?不除的话batch_size会影响loss的

loss_a和loss_b应该都是平均之后的,torch.nn.functional.cross_entropy默认reduction='mean',这里确实应该再除2。

yym6472 avatar Dec 10 '21 11:12 yym6472