
InfoNCE errors

Open davidireland-iso opened this issue 1 year ago • 3 comments

I believe there are some issues with the InfoNCE loss. After stepping through the code, I see that the denominator is calculated only over the negatives in the batch (it should be like SupConLoss and calculated over all items in the batch except the anchor itself). Also, unless I am mistaken, there appear to be multiple positives used in the numerator, whereas with InfoNCE only a single positive should be used. I think this puts the implementation somewhere between the 'real' InfoNCE and SupConLoss?

davidireland-iso avatar Jun 24 '24 14:06 davidireland-iso

Sorry for the delay. I believe the denominator does include all negatives plus the positive, for each positive, as you can see here:

https://github.com/KevinMusgrave/pytorch-metric-learning/blob/adfb78ccb275e5bd7a84d2693da41ff3002eed32/src/pytorch_metric_learning/losses/ntxent_loss.py#L34

As for the number of positives, the loss is computed for each anchor-positive pair separately. So the numerator always consists of a single positive pair. See the section here titled "How exactly is the NTXentLoss computed?": https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss
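To make that per-pair computation concrete, here is a minimal pure-Python sketch of the formula from the docs page. The toy embeddings, labels, and temperature are my own invention; the actual library operates on tensors and supports configurable distance measures.

```python
import math

# Toy 2D unit-norm embeddings; labels group items (0, 1) and (2, 3).
embs = [(1.0, 0.0), (0.8, 0.6), (0.0, 1.0), (-0.6, 0.8)]
labels = [0, 0, 1, 1]
temperature = 0.5

def cos(u, v):
    # Cosine similarity; the dot product suffices for unit-norm vectors.
    return u[0] * v[0] + u[1] * v[1]

pair_losses = []
for a in range(len(embs)):
    for p in range(len(embs)):
        if a != p and labels[a] == labels[p]:
            # Numerator: a single anchor-positive pair.
            num = math.exp(cos(embs[a], embs[p]) / temperature)
            # Denominator: that same positive plus all of the anchor's negatives.
            den = num + sum(
                math.exp(cos(embs[a], embs[n]) / temperature)
                for n in range(len(embs))
                if labels[n] != labels[a]
            )
            pair_losses.append(-math.log(num / den))

# One loss term per positive pair, averaged at the end.
loss = sum(pair_losses) / len(pair_losses)
```

With this batch there are four positive pairs, so four loss terms are averaged.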

I also have tests that confirm this. You can view the relevant part here: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/adfb78ccb275e5bd7a84d2693da41ff3002eed32/tests/losses/test_ntxent_loss.py#L103-L158

The relevant variables are numeratorA, denominatorA, curr_lossA, total_lossA, and obtained_losses[0].

KevinMusgrave avatar Jun 28 '24 18:06 KevinMusgrave

Thanks for getting back to me. I missed the addition of the numerator when stepping through!

I understand that the numerator is calculated for a single anchor-positive pair at a time, and that a loss term is computed for every positive of each anchor in the batch. But therein lies my problem: I was under the impression that the canonical InfoNCE would compute the loss for an anchor using only a single positive per anchor in the batch (if there are multiple positives, one can be chosen arbitrarily). I'm not sure how much difference this makes in practice, but under the current implementation I can't see much difference between InfoNCE and SupConLoss, other than that the log-probabilities are not averaged over the number of positives for each anchor. Is that line of thinking correct, or am I missing something else?
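If I understand the averaging point correctly, the two losses then differ only in how the per-pair terms are weighted. A schematic example with made-up -log(prob) values (not computed from real embeddings) shows the difference appearing only when anchors have unequal numbers of positives:

```python
# Made-up per-pair -log(prob) terms, grouped by anchor: anchor 0 has
# two positives in the batch, anchor 1 has only one.
per_anchor = {0: [1.0, 3.0], 1: [5.0]}

all_pairs = [t for terms in per_anchor.values() for t in terms]

# NTXent-style: mean over all positive pairs.
ntxent_style = sum(all_pairs) / len(all_pairs)  # (1 + 3 + 5) / 3 = 3.0

# SupCon-style (schematically): mean of per-anchor means.
supcon_style = sum(
    sum(terms) / len(terms) for terms in per_anchor.values()
) / len(per_anchor)  # (2.0 + 5.0) / 2 = 3.5
```

When every anchor has the same number of positives, the two averages coincide.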

davidireland-iso avatar Jun 28 '24 19:06 davidireland-iso

> compute the loss for an anchor just for a single positive per anchor in the batch (if there are multiple positives then it can be chosen arbitrarily)

To me it doesn't make sense to use only 1 positive when there are multiple. If the user wants 1 positive per anchor, they should construct the batch that way. Alternatively, since NTXentLoss returns a loss per positive pair, you can write a custom reducer that keeps only the loss for 1 positive pair per anchor.
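The selection step such a reducer would perform is simple. Here is a hedged sketch with made-up per-pair losses and anchor indices; in the library this logic would live inside a subclass of reducers.BaseReducer, whose exact method signatures I'm not reproducing here.

```python
# Hypothetical per-positive-pair losses and the anchor index of each
# pair, as a per-pair loss output would expose them.
pair_losses = [0.9, 0.4, 0.7, 0.2, 0.5]
anchor_idx = [0, 0, 1, 2, 2]

# Keep only one positive pair per anchor (arbitrary choice: the first
# pair seen for that anchor wins), then average the kept losses.
kept = {}
for a, l in zip(anchor_idx, pair_losses):
    if a not in kept:
        kept[a] = l

loss = sum(kept.values()) / len(kept)  # (0.9 + 0.7 + 0.2) / 3 = 0.6
```

Any other selection rule (random choice, hardest positive, etc.) slots into the same structure.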

I can't remember the exact difference between NTXentLoss and SupConLoss, but yes, the differences are subtle. That's probably why I test SupConLoss in the same file as NTXentLoss: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/tests/losses/test_ntxent_loss.py

Similarly, there's another paper that proposes a loss which turns out to be a slight variation of NTXentLoss. See: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/664

KevinMusgrave avatar Jun 29 '24 13:06 KevinMusgrave

I agree, it doesn't make sense, which is presumably why SupConLoss was formalised in its own paper. Anyway, thanks for the discussion; this has cleared up my confusion about the implementation!

davidireland-iso avatar Jul 01 '24 08:07 davidireland-iso