contrastive_loss
Question about `logits_mask`
Thanks for sharing this. I'm having trouble understanding `logits_mask`:

```python
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
mask = mask * logits_mask

# compute log_prob
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
```

My understanding is that `logits_mask` is used to select the negative pairs. If so, why is it not `~mask`, but instead a matrix whose diagonal entries are 0 and whose other entries are 1?
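To make the question concrete, here is a small standalone sketch that rebuilds `logits_mask` on a toy batch and compares it with a negatives-only mask. The labels, batch size, and random logits are assumptions made up for illustration (with `anchor_count = 1` and no `device`), and the names `pos_mask` and `neg_mask` are hypothetical. It shows that `logits_mask` equals the positives-without-self mask plus the negatives mask, so the denominator in the snippet sums `exp(logits)` over every `k != i`, positives included, not over negatives alone.

```python
import torch

# Toy setup (assumed): 4 samples, one view each, labels [0, 0, 1, 1].
labels = torch.tensor([0, 0, 1, 1]).view(-1, 1)
batch_size, anchor_count = labels.shape[0], 1

# Same-label mask, still including the self-pairs on the diagonal.
mask = torch.eq(labels, labels.T).float()

# logits_mask built as in the snippet: ones everywhere, zeros on the diagonal.
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1),
    0,
)

# Hypothetical name: positives used in the numerator (same label, self removed).
pos_mask = mask * logits_mask

# Hypothetical name: negatives only. mask is a float tensor here,
# so its complement is written as 1 - mask.
neg_mask = 1 - mask

# Arbitrary logits, only to exercise the denominator.
torch.manual_seed(0)
logits = torch.randn(batch_size, batch_size)

# Denominator as in the snippet: sum of exp(logits) over every k != i.
log_prob = logits - torch.log((torch.exp(logits) * logits_mask).sum(1, keepdim=True))

# logits_mask covers positives (minus self) and negatives together.
print(torch.equal(logits_mask, pos_mask + neg_mask))  # True
```

With this choice, `exp(log_prob)` restricted to the off-diagonal entries forms a softmax over all other samples in the batch. Using `1 - mask` in the denominator instead would drop the other positives from the normalization.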