targeted-supcon

Seems the code is inconsistent with the paper

Open liluhu0 opened this issue 1 year ago • 7 comments

First of all, thanks for such excellent work! I have a question about your code, which seems to be inconsistent with the paper.

Your code

loss_target = - torch.sum((mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1)) / mask_pos_view.shape[0]

But according to the formula in the paper, it seems it should be:

loss_target = - torch.sum((mask_pos_view_target * log_prob).sum(1)) / mask_pos_view.shape[0]

That is, the target loss in the paper does not seem to be divided by (k+1), which would affect the choice of tw. The paper reports an optimal tw=0.2, so was this tw computed using the formula in the paper or the one in the code? I look forward to your answer. Thank you!

liluhu0 avatar Jul 31 '23 02:07 liluhu0
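The gap between the two expressions in the question can be checked on toy data. The sketch below is only an illustration, not the repository's code: the tensor sizes, the random `log_prob`, and both masks are made-up stand-ins, and every row is given exactly three positives (so k+1 = 3) to keep the comparison simple.

```python
import torch

torch.manual_seed(0)

batch_size, num_keys = 4, 6  # toy sizes, chosen only for illustration

# Stand-in for the log-probabilities over keys (not the model's real output).
log_prob = torch.log_softmax(torch.randn(batch_size, num_keys), dim=1)

# Hypothetical masks: every sample has exactly 3 positives (k + 1 = 3),
# and the first of them plays the role of the target.
mask_pos_view = torch.zeros(batch_size, num_keys)
mask_pos_view[:, :3] = 1.0
mask_pos_view_target = torch.zeros_like(mask_pos_view)
mask_pos_view_target[:, 0] = 1.0

per_sample = (mask_pos_view_target * log_prob).sum(1)

# Expression from the code: divide each sample by its number of positives.
loss_code = -torch.sum(per_sample / mask_pos_view.sum(1)) / mask_pos_view.shape[0]

# Expression as read from the paper's formula: no per-sample division.
loss_paper = -torch.sum(per_sample) / mask_pos_view.shape[0]

print("code :", loss_code.item())
print("paper:", loss_paper.item())
print("ratio:", (loss_paper / loss_code).item())
```

Under the assumption that every sample has the same number of positives, the two losses differ by exactly that count, so a tw tuned for one form would correspond to a rescaled tw for the other.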

Hi, thanks for pointing this out! I just double-checked the paper -- it seems that the paper is missing a parenthesis in the two contrastive losses. Please follow the code for this; the optimal tw was computed using the code.

LTH14 avatar Jul 31 '23 03:07 LTH14

Okay, thanks for your response!

liluhu0 avatar Jul 31 '23 07:07 liluhu0

Hi, I have a question. Doesn't (k+1) appear in mask_pos_view.shape[0] in the code? I think the correct code is:

loss_target = - torch.sum((mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1))

suminRoh avatar Aug 03 '23 02:08 suminRoh

That's true, but you still need to divide by the batch size, which is mask_pos_view.shape[0]

LTH14 avatar Aug 03 '23 02:08 LTH14

Why do I have to divide by the batch size? In main_moco_supcon_imaba.py, doesn't the AverageMeter for the losses already compute the average loss?

Also, if I have to divide by the batch size, which is mask_pos_view.shape[0], then doesn't loss_class have to be divided by mask_pos_view.shape[0] twice, because of the batch size and (k+1)?

suminRoh avatar Aug 03 '23 02:08 suminRoh
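For context on the AverageMeter question, here is a minimal sketch of the running-average helper in the style of the one commonly seen in PyTorch training scripts. It is an assumption that the script in this repository follows the same pattern; the class body and the sample numbers below are illustrative only.

```python
class AverageMeter:
    """Tracks a running, count-weighted average of logged scalar values."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        # val is a plain float (e.g. loss.item()); n is typically the batch size.
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count else 0.0


meter = AverageMeter()

# Hypothetical (per-batch mean loss, batch size) pairs from three iterations.
for loss_value, batch_size in [(2.0, 4), (1.0, 4), (3.0, 2)]:
    meter.update(loss_value, batch_size)

print(meter.avg)
```

A meter of this kind only aggregates the numbers it is handed for logging. It never touches the loss tensor that is backpropagated, so it does not replace averaging over the batch inside the loss itself.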

Basically, the loss for each data point is (mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1). Then we sum over all data points and divide by the batch size, which is equivalent to computing the average loss: torch.sum((mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1)) / mask_pos_view.shape[0]

LTH14 avatar Aug 03 '23 03:08 LTH14
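The explanation above separates two normalizations: `mask_pos_view.sum(1)` is the per-sample number of positives, while `mask_pos_view.shape[0]` is the batch size. A small sketch, assuming made-up tensors rather than the repository's real inputs, shows that the sum divided by the batch size is the same as a mean over samples:

```python
import torch

torch.manual_seed(0)

batch_size, num_keys = 5, 8  # toy sizes, chosen only for illustration

# Stand-in log-probabilities and hypothetical masks.
log_prob = torch.log_softmax(torch.randn(batch_size, num_keys), dim=1)
mask_pos_view = (torch.rand(batch_size, num_keys) > 0.5).float()
mask_pos_view[:, 0] = 1.0  # at least one positive per row, so no division by zero
mask_pos_view_target = torch.zeros_like(mask_pos_view)
mask_pos_view_target[:, 0] = 1.0

# Loss for each data point: normalized by that row's number of positives.
per_sample_loss = (mask_pos_view_target * log_prob).sum(1) / mask_pos_view.sum(1)

# Sum over the batch, then divide by the batch size ...
loss_sum_div = -torch.sum(per_sample_loss) / mask_pos_view.shape[0]

# ... which matches taking the mean over the batch dimension.
loss_mean = -per_sample_loss.mean()

print(mask_pos_view.shape[0])   # batch size
print(mask_pos_view.sum(1))     # positives per sample, may vary by row
print(torch.allclose(loss_sum_div, loss_mean))
```

Dropping the division by `mask_pos_view.shape[0]` would make the loss grow with the batch size, which is a different effect from the per-sample (k+1) normalization.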

I understand. Thank you for explaining in detail!

suminRoh avatar Aug 03 '23 05:08 suminRoh