CSI
Some questions about Supervised_NT_xent
Hello, I have recently been reading this paper.
While using the Supervised_NT_xent loss, I noticed something that may be an issue.
In the SupCLR paper, a positive pair in the loss is a pair (i, j) with i ≠ j where label_i equals label_j. The pair (i, i) is not regarded as a positive pair, even though label_i trivially equals label_i.
However, when I use the Supervised_NT_xent loss from your code and compute Mask, I notice that Mask[i, i] is not zero. Therefore, the pair (i, i) is also treated as a positive pair when the loss is calculated.
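For reference, this is the supervised contrastive loss as I read it from the SupCLR paper. The symbols are my own transcription and may not match the paper's notation exactly: z_i is the normalized embedding of sample i, y_i its label, τ the temperature, and I the index set of all samples in the multi-view batch.

```latex
\mathcal{L}^{sup} = \sum_{i \in I} \frac{-1}{|P(i)|} \sum_{p \in P(i)}
  \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \in A(i)} \exp(z_i \cdot z_a / \tau)},
\qquad A(i) = I \setminus \{i\},
\qquad P(i) = \{\, p \in A(i) : y_p = y_i \,\}
```

Since P(i) is a subset of A(i), which excludes i itself, the index i never appears among its own positives.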
https://github.com/alinlab/CSI/blob/60742b60a16501350eca823fcc910ddd10f7a379/training/contrastive_loss.py#L72-L74
Perhaps line 72 should be:
Mask = torch.eq(labels, labels.t()).float().to(device) * (1 - eye)
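To illustrate the difference, here is a minimal standalone sketch of the two mask variants. It is not the repository's code: the helper name `build_positive_mask`, the `exclude_self` flag, and the toy labels are hypothetical and only meant to show what the diagonal looks like in each case.

```python
import torch


def build_positive_mask(labels: torch.Tensor, exclude_self: bool = True) -> torch.Tensor:
    """Hypothetical helper: (B, B) float mask, 1 where label_i == label_j."""
    labels = labels.contiguous().view(-1, 1)        # shape (B, 1)
    mask = torch.eq(labels, labels.t()).float()     # includes the diagonal (i, i)
    if exclude_self:
        eye = torch.eye(labels.size(0), device=labels.device)
        mask = mask * (1 - eye)                     # zero out the (i, i) entries
    return mask


# Toy labels, assumed for illustration only.
labels = torch.tensor([0, 1, 0, 1])
mask_with_self = build_positive_mask(labels, exclude_self=False)
mask_no_self = build_positive_mask(labels, exclude_self=True)

print(mask_with_self.diagonal())  # all ones: (i, i) counted as positive
print(mask_no_self.diagonal())    # all zeros: (i, i) excluded
```

With the diagonal kept, each row's positive count is one larger than in the SupCLR definition, so any per-row normalization of the mask would also change.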
Could you please clarify whether including the pair (i, i) is intended? Looking forward to your reply!