
self_supervised_contrastive_loss

Open · timtian12 opened this issue 1 year ago · 0 comments

def self_supervised_contrastive_loss(self, features):
    '''Compute the self-supervised VPCL loss.

    Parameters
    ----------
    features: torch.Tensor
        the encoded features of multiple partitions of input tables, with shape ``(bs, n_partition, proj_dim)``.

    Returns
    -------
    loss: torch.Tensor
        the computed self-supervised VPCL loss.
    '''
    batch_size = features.shape[0]
    labels = torch.arange(batch_size, dtype=torch.long, device=self.device).view(-1,1)
    # (bs, bs) same-sample indicator; with arange labels this is the identity matrix
    mask = torch.eq(labels, labels.T).float().to(labels.device)
    contrast_count = features.shape[1]
    # flatten the partitions along the batch axis, partition-major: [[0,1],[2,3]] -> [0,2,1,3]
    contrast_feature = torch.cat(torch.unbind(features,dim=1),dim=0)
    anchor_feature = contrast_feature
    anchor_count = contrast_count
    # pairwise similarities between all (partition, sample) rows, scaled by temperature
    anchor_dot_contrast = torch.div(torch.matmul(anchor_feature, contrast_feature.T), self.temperature)
    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()
    
    # tile the (bs, bs) mask up to (bs*n_partition, bs*n_partition)
    mask = mask.repeat(anchor_count, contrast_count)
    # logits_mask zeroes the diagonal so a row is never compared with itself
    logits_mask = torch.scatter(torch.ones_like(mask), 1, torch.arange(batch_size * anchor_count).view(-1, 1).to(features.device), 0)
    # keep only the off-diagonal entries selected by mask as positives
    mask = mask * logits_mask
    # compute log_prob: log-softmax per row, with the self term excluded from the denominator
    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
    # compute mean of log-likelihood over the positives selected by mask
    mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
    loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
    loss = loss.view(anchor_count, batch_size).mean()
    return loss
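
For reference, this is how I ran the method in isolation while looking at it. It is a minimal sketch: VPCLHost and the sizes and temperature values below are stand-ins I made up, not transtab classes or defaults, and it assumes the function above has been pasted at module level.

    import torch
    import torch.nn.functional as F

    class VPCLHost:
        # hypothetical minimal host exposing only the attributes the method reads
        def __init__(self, temperature=0.07, base_temperature=0.07, device='cpu'):
            self.temperature = temperature
            self.base_temperature = base_temperature
            self.device = device

    # attach the function above (assumed pasted at module level) as a method
    VPCLHost.self_supervised_contrastive_loss = self_supervised_contrastive_loss

    bs, n_partition, proj_dim = 4, 2, 8                     # made-up toy sizes
    features = F.normalize(torch.randn(bs, n_partition, proj_dim), dim=-1)

    host = VPCLHost()
    loss = host.self_supervised_contrastive_loss(features)
    print(loss)                                             # scalar tensor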

I have a question about the function above. When the final loss is calculated, partitions of different samples are multiplied together: through the masking, each row appears to be restricted to the products between the j-th partition of the i-th sample and the partitions of the other samples, and the final loss is then computed from those terms.

In contrastive learning, the loss that is usually optimized has the form exp(z_i · z_j) / Σ_k exp(z_i · z_k), where the numerator is the similarity between different partitions of the same sample and the denominator is the similarity with other samples' partitions; the goal is to make the numerator large relative to the denominator. However, in the function above I can only see the denominator; the numerator does not seem to be present. Is this my misunderstanding, or is there a problem? As far as I can tell, the loss for sample i is computed only against the other samples' partitions.
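
For concreteness, the textbook form I have in mind looks roughly like this (a toy sketch with made-up vectors and temperature, not transtab code):

    import torch
    import torch.nn.functional as F

    tau = 0.07                                             # made-up temperature
    z_i = F.normalize(torch.randn(8), dim=0)               # anchor: one partition of sample i
    z_j = F.normalize(torch.randn(8), dim=0)               # positive: another partition of sample i
    z_neg = F.normalize(torch.randn(5, 8), dim=1)          # negatives: partitions of other samples

    num = torch.exp(z_i @ z_j / tau)                       # numerator: positive-pair similarity
    den = num + torch.exp(z_neg @ z_i / tau).sum()         # denominator: positive + negatives
    loss = -torch.log(num / den)                           # minimized by making num large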

timtian12 · Nov 21 '23 09:11