FactorCL
About the conditional loss in the FactorCL-SSL case
Thank you for your great work! I have a question about the conditional loss in the FactorCL-SSL case. In `IRFL_model.py`, line 311, the conditional CLUB loss is computed as follows:
```python
self.club_x1x2_cond(
    torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
               self.linears_club_x1x2_cond[0](x1_embed)], dim=1),
    torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
               self.linears_club_x1x2_cond[1](x2_embed)], dim=1))
```
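However, following Eq. (8) in the paper, I think each embedding should be concatenated with its augmented counterpart (`x1_embed` with `x1_aug_embed`, and `x2_embed` with `x2_aug_embed`) rather than with itself, like this: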
```python
self.club_x1x2_cond(
    torch.cat([self.linears_club_x1x2_cond[0](x1_embed),
               self.linears_club_x1x2_cond[0](x1_aug_embed)], dim=1),
    torch.cat([self.linears_club_x1x2_cond[1](x2_embed),
               self.linears_club_x1x2_cond[1](x2_aug_embed)], dim=1))
```
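To make the difference concrete, here is a minimal pure-Python sketch (not the repository's code). It assumes a single sample per batch, uses list concatenation in place of `torch.cat(..., dim=1)`, and replaces `linears_club_x1x2_cond[0]` with a hypothetical identity-weight linear layer; `make_projection` and `cond_input` are illustrative helpers I made up. It only shows what ends up in the conditioning input of each version.

```python
from typing import Callable, List

Vector = List[float]


def make_projection(weights: List[Vector]) -> Callable[[Vector], Vector]:
    """Hypothetical linear layer standing in for linears_club_x1x2_cond[i]."""
    def project(x: Vector) -> Vector:
        # Matrix-vector product: one output per weight row.
        return [sum(w * xi for w, xi in zip(row, x)) for row in weights]
    return project


def cond_input(proj: Callable[[Vector], Vector], embed: Vector,
               aug_embed: Vector, use_aug: bool) -> Vector:
    """Concatenate the projected embedding with a second projected view.

    List concatenation plays the role of torch.cat([...], dim=1) for one sample.
    use_aug=False mirrors the current code (embedding concatenated with itself);
    use_aug=True mirrors the proposed change (embedding + augmented embedding).
    """
    second = aug_embed if use_aug else embed
    return proj(embed) + proj(second)


proj0 = make_projection([[1.0, 0.0], [0.0, 1.0]])  # identity weights for readability
x1_embed = [1.0, 2.0]
x1_aug_embed = [0.5, -1.0]

current = cond_input(proj0, x1_embed, x1_aug_embed, use_aug=False)
proposed = cond_input(proj0, x1_embed, x1_aug_embed, use_aug=True)

print(current)   # [1.0, 2.0, 1.0, 2.0]
print(proposed)  # [1.0, 2.0, 0.5, -1.0]
```

With the current code, the second half of the input is an exact copy of the first, so the estimator never sees `x1_aug_embed` at all; with the proposed change, the augmented view actually enters the input.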
Since I'm a beginner in this field, I may have misunderstood something. Could you confirm whether the current implementation is intended? Your response would be really helpful to me! Thank you.