SimCLR
Attempt to add support for n_views >= 3
@sthalles @alessiamarcolini @butyuhao Hi,
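As mentioned in #32, the current implementation of `info_nce_loss` may not work properly when `n_views > 2`. In that case each sample has `n_views - 1` positive pairs, but every row gets the target label `0`, so only the first positive column is treated as the target and the remaining positives end up in the softmax denominator alongside the negatives. Here I attempt to fix it by duplicating the negative pairs for each additional positive pair, if I understand the mechanism of your current implementation correctly: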
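```python
# one positive per row: shape (N * (n_views - 1), 1), with N = batch_size * n_views
positives = similarity_matrix[labels.bool()].view(labels.shape[0] * (n_views - 1), -1)
# select only the negatives
# change: copy each anchor's negatives once per additional positive pair if n_views > 2;
# repeat_interleave keeps row i's negatives next to row i's positives, whereas
# .repeat(n_views - 1, 1) would tile the whole matrix and misalign the rows
negatives = similarity_matrix[~labels.bool()].view(similarity_matrix.shape[0], -1).repeat_interleave(n_views - 1, dim=0)
logits = torch.cat([positives, negatives], dim=1)
```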
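To check the shapes outside the trainer, here is a minimal self-contained sketch of the whole logits construction with this change applied. It mirrors the structure of `info_nce_loss`, but the free function `info_nce_logits` and its arguments (`batch_size`, `n_views`, `temperature`) are hypothetical stand-ins for the `self.args` fields, and it runs on CPU only. The features are assumed to be stacked view-major (all first views, then all second views, and so on).

```python
import torch
import torch.nn.functional as F


def info_nce_logits(features, batch_size, n_views, temperature=0.07):
    """Hypothetical standalone version of the logits construction above.

    Assumes `features` has shape (batch_size * n_views, dim), stacked
    view-major: all first views, then all second views, and so on.
    """
    n = batch_size * n_views
    # labels[i, j] == 1 when samples i and j are views of the same image
    ids = torch.arange(batch_size).repeat(n_views)
    labels = (ids.unsqueeze(0) == ids.unsqueeze(1)).float()

    features = F.normalize(features, dim=1)
    similarity_matrix = features @ features.T

    # discard the main diagonal (self-similarity) from both matrices
    mask = torch.eye(n, dtype=torch.bool)
    labels = labels[~mask].view(n, -1)
    similarity_matrix = similarity_matrix[~mask].view(n, -1)

    # one row per (anchor, positive) pair, ordered anchor by anchor
    positives = similarity_matrix[labels.bool()].view(n * (n_views - 1), -1)
    # each anchor's negatives, repeated once per positive of that anchor
    negatives = similarity_matrix[~labels.bool()].view(n, -1)
    negatives = negatives.repeat_interleave(n_views - 1, dim=0)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    # the positive is always in column 0, so every target is 0
    targets = torch.zeros(logits.shape[0], dtype=torch.long)
    return logits, targets


torch.manual_seed(0)
feats = torch.randn(4 * 3, 16)  # batch_size=4, n_views=3, dim=16
logits, targets = info_nce_logits(feats, batch_size=4, n_views=3)
loss = F.cross_entropy(logits, targets)
print(logits.shape)  # torch.Size([24, 10])
```

With `batch_size=4` and `n_views=3` there are 12 samples, each with 2 positives and 9 negatives, so the cross-entropy is taken over 24 rows of 10 logits, one row per (anchor, positive) pair.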