self_supervised_contrastive_loss
```python
def self_supervised_contrastive_loss(self, features):
    '''Compute the self-supervised VPCL loss.

    Parameters
    ----------
    features: torch.Tensor
        the encoded features of multiple partitions of input tables,
        with shape ``(bs, n_partition, proj_dim)``.

    Returns
    -------
    loss: torch.Tensor
        the computed self-supervised VPCL loss.
    '''
    batch_size = features.shape[0]
    labels = torch.arange(batch_size, dtype=torch.long, device=self.device).view(-1, 1)
    # (bs, bs) matrix with 1 where two indices belong to the same sample
    mask = torch.eq(labels, labels.T).float().to(labels.device)
    contrast_count = features.shape[1]
    # flatten partitions: [[0,1],[2,3]] -> [0,2,1,3]
    contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
    anchor_feature = contrast_feature
    anchor_count = contrast_count
    # pairwise similarities between all (sample, partition) pairs, scaled by temperature
    anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
    # subtract the per-row max for numerical stability
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()
    mask = mask.repeat(anchor_count, contrast_count)
    # logits_mask zeroes out each anchor's similarity with itself
    logits_mask = torch.scatter(
        torch.ones_like(mask), 1,
        torch.arange(batch_size * anchor_count).view(-1, 1).to(features.device), 0)
    mask = mask * logits_mask
    # compute log_prob
    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    # compute mean of log-likelihood over positives
    mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
    loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
    loss = loss.view(anchor_count, batch_size).mean()
    return loss
```
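For anyone who wants to run the function standalone, here is a minimal harness; the `VPCLDemo` wrapper, the temperature values, and the random features are illustrative assumptions, not TransTab's actual encoder or defaults.

```python
import torch

class VPCLDemo:
    """Hypothetical stand-in exposing only the attributes the loss reads."""
    def __init__(self, temperature=0.07, base_temperature=0.07, device="cpu"):
        self.temperature = temperature
        self.base_temperature = base_temperature
        self.device = device

# Attach the function above (assumed pasted at module level) as a method.
VPCLDemo.self_supervised_contrastive_loss = self_supervised_contrastive_loss

bs, n_partition, proj_dim = 4, 2, 8
# L2-normalized random projections standing in for the encoder output.
features = torch.nn.functional.normalize(torch.randn(bs, n_partition, proj_dim), dim=-1)
loss = VPCLDemo().self_supervised_contrastive_loss(features)
print(loss)  # a positive scalar
```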
I have a question about the function above. When the final loss is computed, partitions of different samples are multiplied together: through masking, each row is restricted so that the j-th partition of the i-th sample is compared against the partitions of the other samples, and the final loss is obtained from that. In contrastive learning, the usual objective has the form `-log( exp(z_i·z_j) / Σ_k exp(z_i·z_k) )`, where the numerator is the similarity between different partitions of the same sample and the denominator sums similarities with partitions of other samples; the goal is to make the numerator large relative to the denominator. However, in the function above I can only see the denominator; the numerator does not seem to be present. Is this my misunderstanding, or is there a problem with the loss computed for sample i?
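To make the question concrete, here is a small standalone sketch (illustrative sizes, not TransTab code) that reproduces just the mask construction and prints which similarity entries end up selected by `mask * log_prob`:

```python
import torch

bs, n_partition = 2, 2  # illustrative sizes
labels = torch.arange(bs).view(-1, 1)
mask = torch.eq(labels, labels.T).float()      # (bs, bs): 1 where indices share a sample
mask = mask.repeat(n_partition, n_partition)   # expand to all (sample, partition) pairs
logits_mask = torch.scatter(                   # zero out each anchor's self-similarity
    torch.ones_like(mask), 1,
    torch.arange(bs * n_partition).view(-1, 1), 0)
# Row order after torch.cat(torch.unbind(features, dim=1), dim=0) is
# [s0p0, s1p0, s0p1, s1p1]; the 1s below mark the pairs that
# mean_log_prob_pos averages over.
print(mask * logits_mask)
```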