对比损失计算问题
您好,我想请教一下,我把您的计算对比损失的代码拿到我的模型中,得到的损失值是nan?这是什么原因呢?我的代码如下: def get_seq_hidden(self,encoder_out,encoder_padding_mask): #audio_token=[B,T1,H],text_token=[B,T2,H] encoder_padding_mask = (~encoder_padding_mask).float() denom = encoder_padding_mask.sum(dim=1,keepdim=True).clamp(min=1e-6) seq_hidden = (encoder_out * encoder_padding_mask).sum(dim=1) / denom return seq_hidden
def compute_contrastive_loss(self,audio_token,audio_encoder_padding_mask,text_token,text_encoder_padding_mask,contrastive_temperature): audio_seq_hidden = self.get_seq_hidden(audio_token,audio_encoder_padding_mask) text_seq_hidden = self.get_seq_hidden(text_token,text_encoder_padding_mask) batch_size,hidden_size = audio_seq_hidden.size() audio_seq_hidden = F.normalize(audio_seq_hidden,p=2,dim=-1) text_seq_hidden = F.normalize(text_seq_hidden,dim=-1) logits = F.cosine_similarity(audio_seq_hidden.expand((batch_size,batch_size,hidden_size)), text_seq_hidden.expand((batch_size,batch_size,hidden_size)).transpose(0,1), dim=-1)
logits /= contrastive_temperature
loss_audio = -torch.nn.LogSoftmax(0)(logits).diag()
loss_text = -torch.nn.LogSoftmax(1)(logits).diag()
loss = loss_audio + loss_text
return loss.sum()
期待您的回复。