gritlm
question about emb loss
```python
def __call__(self, q_reps, p_reps):
    if self.negatives_cross_device:
        # This gathers both negatives and positives.
        # It could likely be optimized by only gathering negatives.
        q_reps = self._dist_gather_tensor(q_reps)
        p_reps = self._dist_gather_tensor(p_reps)
    scores = self.compute_similarity(q_reps, p_reps) / self.temperature
    scores = scores.view(q_reps.size(0), -1)
    target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
    target *= (p_reps.size(0) // q_reps.size(0))
    return self.cross_entropy(scores, target)
```
In this code, is the contrastive loss implemented as described in the paper?
Yes, that's the contrastive loss.
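For intuition, the simplest case (one positive per query, no hard negatives) reduces to standard in-batch contrastive loss: the score matrix is queries vs. passages, and each query's positive sits on the diagonal. A minimal sketch, assuming a plain dot-product similarity and an illustrative temperature of 0.02 (the function name and sizes here are hypothetical, not from the repo):

```python
import torch
import torch.nn.functional as F

def in_batch_contrastive_loss(q_reps, p_reps, temperature=0.02):
    # (B, B) similarity matrix: row i scores query i against every passage.
    scores = q_reps @ p_reps.T / temperature
    # With one positive per query and no hard negatives, the positive for
    # query i is passage i, so the target is just the diagonal index.
    target = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, target)

q = torch.randn(4, 8)
p = torch.randn(4, 8)
loss = in_batch_contrastive_loss(q, p)
```

All other passages in the batch act as in-batch negatives for each query, which is what makes the cross-entropy over the score matrix a contrastive objective.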
```python
target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
target *= (p_reps.size(0) // q_reps.size(0))
```
Why is the target constructed this way? I'm a little confused.

My guess: each query corresponds to a group of passage samples, the division `p_reps.size(0) // q_reps.size(0)` counts how many passages belong to each query, and multiplying `arange` by that count gives the index of each query's positive passage. Is that right?
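That reading matches the arithmetic. If each query is paired with one positive followed by its hard negatives, the passages are laid out group by group, so the positive for query `i` sits at index `i * group_size`. A small sketch with hypothetical sizes (4 queries, each with 1 positive + 2 hard negatives):

```python
import torch

# Hypothetical layout: p_reps rows are grouped per query as
# [pos_0, neg_0a, neg_0b, pos_1, neg_1a, neg_1b, ...]
num_queries = 4
group_size = 3                       # 1 positive + 2 hard negatives
num_passages = num_queries * group_size

target = torch.arange(num_queries, dtype=torch.long)
target *= (num_passages // num_queries)  # == group_size

print(target.tolist())  # [0, 3, 6, 9] -> index of each query's positive
```

So `cross_entropy(scores, target)` treats the positive at the start of each group as the correct class, and everything else in the row (hard negatives plus other queries' passages) as negatives.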