SimCLR
Info NCE loss
Hi, may I ask how you calculated the InfoNCE loss in this work? I am confused about the methodology, as it is quite different from the authors' code.
You are returning labels that are all 0 because you only want to calculate negative labels. However, in this code here you used the logits for both the negative samples and the positive sample (I'm assuming this is the augmented counterpart of the image). May I ask the reasoning behind this kind of implementation?
https://github.com/sthalles/SimCLR/blob/1848fc934ad844ae630e6c452300433fe99acfd9/simclr.py#L51-L55
P.S.: I am still at a loss as to how you were able to simplify the code to calculate only the negative samples. Hopefully this can be clarified in your reply. Thank you!
Same question! I think the code should be revised to calculate both negative & positive samples.
I think I sort of get it now. The all-zero label array is meant to indicate where the positive is for each sample. Since at line 51 they put the positive similarities at the front of the logits tensor, index 0 is the positive (correct) class, and that is the label for every member of the batch.
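To make this concrete, here is a minimal sketch (not the repository's code) for a single anchor: if its positive similarity sits in column 0 and the negatives follow, cross-entropy with target 0 is exactly the InfoNCE term -log(exp(s_pos/τ) / Σ_j exp(s_j/τ)). The similarity values and the temperature are made up for illustration.

```python
import torch
import torch.nn.functional as F

temperature = 0.5  # illustrative value, not the repo's default
# One anchor: its positive similarity first, then three negative similarities (made-up numbers).
sims = torch.tensor([[0.9, 0.1, -0.2, 0.3]])
logits = sims / temperature

# Cross-entropy with target class 0, i.e. "the positive is in column 0".
ce = F.cross_entropy(logits, torch.zeros(1, dtype=torch.long))

# The InfoNCE term written out by hand: -log(exp(pos) / sum over all columns of exp(logit)).
info_nce = -torch.log(torch.exp(logits[0, 0]) / torch.exp(logits[0]).sum())

print(ce.item(), info_nce.item())  # the two values agree up to float rounding
```

So the negatives are not the only thing being used: the positive enters the numerator through label 0, and all columns enter the denominator.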
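```python
import torch

def shuffle_positive_column(logits, positives, temperature, device):
    # logits: [N, C] from torch.cat([positives, negatives], dim=1), positive in column 0
    # positives: [N, 1]; the function name is illustrative
    n, c = logits.shape
    logits = logits.clone()  # avoid modifying the caller's tensor in place
    # Pick a random target column for each row's positive.
    random_labels = torch.randint(low=0, high=c, size=(n,), device=device)
    rows = torch.arange(n, device=device)
    # Move the value currently in the target column into column 0 ...
    logits[:, 0] = logits[rows, random_labels]
    # ... then write the positive into the target column.
    logits[rows, random_labels] = positives.squeeze(1)
    logits = logits / temperature
    return logits, random_labels
```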
This is a possible solution that places each positive at a random column, so that the target output for the network is not always 0. That said, the original implementation seems to work in their case and in some other experiments I ran, so I guess it is fine.
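A quick check of why the fixed index 0 is harmless: cross-entropy does not care which column holds the positive, as long as the label points to it. This sketch uses random made-up logits and a full per-row column shuffle (via argsort of random numbers) rather than the swap in the snippet above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, c = 8, 15  # hypothetical sizes: 8 anchors, 1 positive + 14 negatives each
logits = torch.randn(n, c)  # column 0 holds each row's positive (made-up values)
fixed_loss = F.cross_entropy(logits, torch.zeros(n, dtype=torch.long))

# Shuffle the columns of every row independently.
perm = torch.rand(n, c).argsort(dim=1)  # perm[i, j] = original column now at position j
shuffled = logits.gather(1, perm)
# The positive (original column 0) now sits wherever perm == 0.
labels = perm.argmin(dim=1)

shuffled_loss = F.cross_entropy(shuffled, labels)
print(fixed_loss.item(), shuffled_loss.item())  # equal up to float rounding
```

Since the per-row loss and its gradients with respect to the underlying similarities are the same under any such permutation, randomizing the position should not change training, which matches the observation that the original version works fine.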
https://github.com/sthalles/SimCLR/issues/16 has a nice explanation. In a nutshell, the first column holds the logits of the positive samples, so all labels are set to zero.
Basically, the implementation puts the positive instances in the first column of the logits (that's why the label is 0 in all cases).
Let N = batch_size * n_views, for example 64. The logits are a [64, 63] tensor and the labels are a [64] tensor. Note the dimensions.
The logits hold each of the 64 samples' similarities to the other 63 samples, with the positive in the first column, so the correct class is 0. The labels are all 0 because every sample's positive pair, although a different sample for each row, sits in that same first column. You could also place the positives in the 63rd (last) column, in which case the correct labels would all be 62.
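The shapes above can be reproduced with a standalone sketch that mirrors the masking steps of the linked info_nce_loss (reconstructed here as a free function, so treat it as an approximation; random features stand in for the projection-head outputs, and the positive_last flag is hypothetical, added only to show the last-column variant).

```python
import torch
import torch.nn.functional as F

def info_nce_logits(features, batch_size, n_views, temperature=0.5, positive_last=False):
    # Rows i and j are a positive pair if they are views of the same image.
    ids = torch.cat([torch.arange(batch_size) for _ in range(n_views)])
    pos_mask = ids.unsqueeze(0) == ids.unsqueeze(1)          # [N, N]
    features = F.normalize(features, dim=1)
    sim = features @ features.T                              # [N, N] cosine similarities
    # Drop the diagonal (each sample's similarity with itself) -> [N, N - 1].
    eye = torch.eye(sim.shape[0], dtype=torch.bool)
    pos_mask = pos_mask[~eye].view(sim.shape[0], -1)
    sim = sim[~eye].view(sim.shape[0], -1)
    positives = sim[pos_mask].view(sim.shape[0], -1)         # [N, n_views - 1]
    negatives = sim[~pos_mask].view(sim.shape[0], -1)        # [N, N - n_views]
    parts = [negatives, positives] if positive_last else [positives, negatives]
    return torch.cat(parts, dim=1) / temperature

batch_size, n_views = 32, 2
N = batch_size * n_views  # 64
features = torch.randn(N, 128)

logits = info_nce_logits(features, batch_size, n_views)
labels = torch.zeros(N, dtype=torch.long)                    # positive in column 0
print(logits.shape, labels.shape)  # torch.Size([64, 63]) torch.Size([64])

# Same loss if the positive is put in the last column and every label is 62.
logits_last = info_nce_logits(features, batch_size, n_views, positive_last=True)
loss_first = F.cross_entropy(logits, labels)
loss_last = F.cross_entropy(logits_last, torch.full((N,), N - 2, dtype=torch.long))
print(loss_first.item(), loss_last.item())
```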
I know why the labels are always zero. But if n_views is not 2 but, for example, 3 or 4, then the positives are not only in column 0: they occupy columns (0, 1) when n_views is 3, and (0, 1, 2) when n_views is 4.
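With the repository's pairing rule (views of the same image are positives), each row has n_views - 1 positives, and because torch.cat([positives, negatives], dim=1) puts them first, they fill columns 0 through n_views - 2. A small check of the count; the helper name and the batch size of 4 are arbitrary.

```python
import torch

def positives_per_row(batch_size, n_views):
    # Same pairing rule as info_nce_loss: rows are positives if they come from the same image.
    ids = torch.cat([torch.arange(batch_size) for _ in range(n_views)])
    pos_mask = ids.unsqueeze(0) == ids.unsqueeze(1)
    pos_mask = pos_mask & ~torch.eye(len(ids), dtype=torch.bool)  # a sample is not its own positive
    return pos_mask.sum(dim=1)

for n_views in (2, 3, 4):
    counts = positives_per_row(batch_size=4, n_views=n_views)
    print(n_views, counts.unique().tolist())
# 2 [1]
# 3 [2]
# 4 [3]
```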
torch.nn.CrossEntropyLoss() = LogSoftmax + NLLLoss; see the documentation of NLLLoss for details.
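A quick sketch of that identity with random made-up logits: NLLLoss simply picks -log_prob[i, labels[i]] for each row and averages, which is why a label of 0 means "column 0".

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(5, 7)              # 5 samples, 7 classes (random values)
labels = torch.tensor([0, 0, 3, 6, 1])  # target column for each row

ce = nn.CrossEntropyLoss()(logits, labels)
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)
# Manual version: pick each row's target log-probability, negate, and average.
manual = -nn.LogSoftmax(dim=1)(logits)[torch.arange(5), labels].mean()
print(ce.item(), nll.item(), manual.item())
```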
If I want to amend the code to support 3 or 4 views, how can I do it? Please give me some tips if you are free. Thanks in advance!
My solution is simple: I didn't modify info_nce_loss, but instead computed the cross-entropy loss several times. Since CrossEntropyLoss is based on NLLLoss, a label vector whose entries all have the same value selects one column of the logits.
For example, if n_views is 4, then the first 3 columns of the logits are positives, so there are 3 cross-entropy calculations: labels of all 0 select the first column, labels of all 1 the second column, and labels of all 2 the third column.
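A minimal sketch of that approach, assuming logits laid out as in info_nce_loss (the n_views - 1 positives in the first columns). The function name multi_positive_ce is hypothetical, and averaging the per-column losses is a choice; summing them would only rescale the loss.

```python
import torch
import torch.nn.functional as F

def multi_positive_ce(logits, n_views):
    # logits: [N, N - 1], with columns 0 .. n_views - 2 holding each row's positives.
    n = logits.shape[0]
    losses = []
    for k in range(n_views - 1):
        # An all-k label vector selects column k as the target for every row.
        labels = torch.full((n,), k, dtype=torch.long, device=logits.device)
        losses.append(F.cross_entropy(logits, labels))
    return torch.stack(losses).mean()

# Usage with made-up logits: 4 images x 4 views -> N = 16 rows, 15 columns.
torch.manual_seed(0)
logits = torch.randn(16, 15)
loss = multi_positive_ce(logits, n_views=4)
print(loss.item())
```

Note that in each term the other positive columns stay in the softmax denominator and are therefore treated like negatives; one alternative would be to mask them out before each term.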
```python
assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
```
This assertion can be found in run.py, so the code as written only supports two views; if we want to try more, we have to modify it ourselves.