contrastive_loss

Question about logits_mask

Open Alva-2020 opened this issue 2 years ago • 0 comments

Thanks for sharing. I have a question about `logits_mask`:

    # mask-out self-contrast cases
    logits_mask = torch.scatter(
        torch.ones_like(mask),
        1,
        torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
        0
    )

    mask = mask * logits_mask

    # compute log_prob
    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

As I understand it, `logits_mask` is used to filter out the negative pairs. Why is it not `~mask`, but instead a matrix with zeros on the diagonal and ones everywhere else?
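To make the comparison concrete, here is a minimal sketch on a toy batch. The labels, the single view per sample (`anchor_count = 1`), and the names `positives` and `negatives_only` are assumptions for illustration and are not taken from the repository. It builds `logits_mask` the same way as the snippet above and puts it next to what `~mask` would keep.

```python
import torch

# Hypothetical toy setup: 4 samples with labels [0, 0, 1, 1], one view each.
labels = torch.tensor([0, 0, 1, 1]).view(-1, 1)
batch_size, anchor_count = labels.shape[0], 1

# mask[i, j] = 1 when samples i and j share a label (this includes i == j).
mask = torch.eq(labels, labels.T).float()

# Same construction as the quoted snippet: ones everywhere, zeros on the diagonal.
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1),
    0,
)

# Same-label pairs with the anchor itself removed (mask * logits_mask).
positives = mask * logits_mask

# What ~mask would keep: different-label pairs only.
negatives_only = 1 - mask

print(logits_mask)
print(positives)
print(negatives_only)
```

In this toy case the two masks differ exactly on the same-label, off-diagonal entries: `logits_mask` keeps them and `1 - mask` drops them. So in the quoted code, `exp_logits.sum(1, keepdim=True)` runs over every other sample in the row, same-label or not, and only the anchor's own column is excluded.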

Alva-2020, Apr 28 '22 14:04