
Understanding the NT-Xent loss function

Open Vedanshi-Shah opened this issue 1 year ago • 1 comments

Could you explain the significance of mask in the NT-Xent loss function?

```python
mask = labels.expand(n, n).eq(labels.expand(n, n).t()).to(a2t.device)  # True where labels[i] == labels[j]
mask_diag = mask.diag()                    # diagonal self-pairs [i, i]
mask_diag = torch.diag_embed(mask_diag)    # back to an n x n matrix
mask = mask ^ mask_diag                    # drop the diagonal, keep off-diagonal label matches

a2t_loss = - self.loss(a2t).masked_fill(mask, 0).diag().mean()
t2a_loss = - self.loss(t2a).masked_fill(mask, 0).diag().mean()
```

From what we have inferred, the mask disregards the diagonal positive pairs (i.e., [i, i]) but keeps the off-diagonal [i, j] pairs (i != j) whose labels match.

In the final a2t_loss calculation, the mean is taken over the diagonal values rather than over the negative pairs. Since NT-Xent loss is supposed to account for the similarity of negative pairs, how is that being computed?
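For concreteness, the mask construction can be traced on a toy label vector (a minimal sketch; the label values here are made up):

```python
import torch

# Hypothetical toy batch: items 0 and 1 share label 0, so [0, 1] and
# [1, 0] are off-diagonal positive pairs; label 1 appears only once.
labels = torch.tensor([0, 0, 1])
n = labels.size(0)

# Same construction as in the code above: True where labels match.
mask = labels.expand(n, n).eq(labels.expand(n, n).t())

# Remove the diagonal self-pairs [i, i] via XOR with the embedded diagonal.
mask_diag = torch.diag_embed(mask.diag())
mask = mask ^ mask_diag

print(mask)
# tensor([[False,  True, False],
#         [ True, False, False],
#         [False, False, False]])
```

Only the off-diagonal entries with matching labels survive, which matches the reading that the mask marks same-label pairs other than an item with itself.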

Vedanshi-Shah avatar Sep 02 '23 19:09 Vedanshi-Shah

Softmax is applied in `self.loss()`.
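Expanding on that: if `self.loss()` is a row-wise log-softmax (an assumption about this repo's implementation), every negative similarity sits in the softmax denominator, so averaging only the diagonal log-probabilities still penalises high negative similarities. A minimal sketch with made-up numbers:

```python
import torch
import torch.nn.functional as F

# Hypothetical 2x2 audio-to-text similarity matrix: diagonal entries are
# positive pairs, off-diagonal entries are negatives.
a2t = torch.tensor([[0.9, 0.1],
                    [0.2, 0.8]])

# Assuming self.loss is a row-wise log-softmax: each diagonal term is
# log( exp(s_ii) / sum_j exp(s_ij) ), so every negative s_ij (j != i)
# appears in the denominator even though only the diagonal is averaged.
a2t_loss = -F.log_softmax(a2t, dim=1).diag().mean()

# Raising a negative similarity shrinks the diagonal probabilities and
# therefore increases the loss.
a2t_harder = torch.tensor([[0.9, 0.7],
                           [0.2, 0.8]])
harder_loss = -F.log_softmax(a2t_harder, dim=1).diag().mean()
print(a2t_loss.item(), harder_loss.item())  # harder_loss > a2t_loss
```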

XinhaoMei avatar Sep 02 '23 20:09 XinhaoMei