
Include SimDINO

swamidass opened this issue 7 months ago · 4 comments

This paper seems to produce a large improvement over DINOv1 and DINOv2 with a simplified approach. Would you consider supporting it?

https://robinwu218.github.io/SimDINO/

swamidass · May 04 '25 19:05

Hi @swamidass! That looks like a really promising paper indeed; thanks for making us aware of it. We'll go over it internally and see what we can do about it.

liopeer · May 05 '25 09:05

From my read-through, it seems fairly straightforward to implement SimDINO (the DINOv1 variant) by creating a modified version of DINOLoss.

I know a DINOv2 implementation is nearing completion. It might be straightforward to modify that implementation to create SimDINOv2.

swamidass · May 06 '25 22:05

It seems that the following code should be approximately correct for SimDINO. It is adapted from https://github.com/RobinWu218/SimDINO/blob/main/simdino/main_dino.py:

import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn
import torch.nn as nn
import torch.nn.functional as F


class MCRLoss(nn.Module):
    def __init__(self, ncrops, reduce_cov=0, expa_type=0, eps=0.5, coeff=1.0):
        super().__init__()
        self.ncrops = ncrops          # total number of crops: 2 global + n local
        self.eps = eps                # distortion parameter of the coding rate
        self.coeff = coeff            # weight of the compression term
        self.reduce_cov = reduce_cov  # 1: all-reduce the covariances across ranks
        self.expa_type = expa_type    # 0: expansion on student global views; 1: on (student + teacher) / 2

    def forward(self, student_feat, teacher_feat):
        """
        Expansion and compression losses between features of the teacher and student networks.
        Expects crop-major stacking: student_feat is (ncrops * B, D), teacher_feat is (2 * B, D).
        """
        student_feat = student_feat.view(self.ncrops, -1, student_feat.shape[-1])
        teacher_feat = teacher_feat.view(2, -1, teacher_feat.shape[-1])

        comp_loss = self.calc_compression(student_feat, teacher_feat)
        if self.expa_type == 0:  # only compute expansion on the student's global views
            expa_loss = self.calc_expansion(student_feat[: len(teacher_feat)])
        elif self.expa_type == 1:  # average the student and teacher global views first
            expa_loss = self.calc_expansion((student_feat[: len(teacher_feat)] + teacher_feat) / 2)
        # Both terms are maximized, hence the negations.
        loss = -self.coeff * comp_loss - expa_loss
        return loss, comp_loss.detach(), expa_loss.detach()
    
    def calc_compression(self, student_feat_list, teacher_feat_list):
        """
        Compression loss: cosine similarity between matched teacher and student crops.
        """
        # sim has shape (2, ncrops, B): teacher view x student view x batch.
        sim = F.cosine_similarity(teacher_feat_list.unsqueeze(1), student_feat_list.unsqueeze(0), dim=-1)
        # Zero out the (teacher i, student i) pairs, i.e. the same view through both networks.
        sim.view(-1, sim.shape[-1])[:: (len(student_feat_list) + 1), :].fill_(0)

        n_loss_terms = len(teacher_feat_list) * len(student_feat_list) - min(len(teacher_feat_list), len(student_feat_list))
        # Average over the batch, then over the remaining view pairs.
        comp_loss = sim.mean(2).sum() / n_loss_terms
        return comp_loss
    
    def calc_expansion(self, feat_list) -> torch.Tensor:
        """
        Expansion loss: average coding-rate estimate over the views.
        """
        num_views = len(feat_list)
        m, p = feat_list[0].shape  # m: samples per rank, p: feature dimension

        # One (p, p) covariance matrix per view.
        cov_list = torch.stack([W.T.matmul(W) for W in feat_list])
        N = 1
        if dist.is_initialized():
            N = dist.get_world_size()
            if self.reduce_cov == 1:
                # Differentiable all-reduce so the coding rate sees the global batch.
                cov_list = dist_nn.all_reduce(cov_list)
        scalar = p / (m * N * self.eps)
        I = torch.eye(p, device=cov_list.device)
        loss = cov_list.new_zeros(())
        for i in range(num_views):
            # sum(log(diag(chol(A)))) equals 0.5 * logdet(A), but is cheaper and more stable.
            loss = loss + torch.linalg.cholesky_ex(I + scalar * cov_list[i])[0].diagonal().log().sum()
        loss = loss / num_views
        # The balancing factor gamma; this is ultimately a heuristic, so feel free to
        # experiment, e.g. with the commented-out alternative below.
        loss = loss * (p + N * m) / (p * N * m)
        # loss = loss * (self.eps * N * m) ** 0.5 / p
        return loss


# Use this in place of DINOLoss in lightly.
dino_loss = MCRLoss(
    ncrops=local_crops_number + 2,  # 2 global crops + local_crops_number local crops
    reduce_cov=0,
    expa_type=1,
    eps=0.5,
    coeff=1.0,
)
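
For anyone cross-checking calc_expansion against the paper: the Cholesky trick works because $\log\det A = 2\sum_i \log L_{ii}$ when $A = LL^\top$, so summing the log-diagonal of the factor gives half the log-determinant directly. In my reading of the code (treat the exact role of eps as my interpretation rather than the paper's notation), each view on a single rank contributes

$$R(W) = \frac{1}{2} \log\det\left(I_p + \frac{p}{m N \epsilon}\, W^\top W\right)$$

where W is that view's (m, p) feature matrix; with reduce_cov=1 the covariance $W^\top W$ is summed across the N ranks first.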

The only thing that might need modification is the stacking of the outputs? I'm not sure that is handled correctly here.
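
In case it helps, here is roughly how I would expect the stacking to work, given that the loss reshapes its inputs crop-major (all of crop 0's batch, then all of crop 1's, and so on). Everything below is illustrative: views, student, teacher, and dino_loss are placeholders for whatever the lightly setup ends up looking like, and the F.normalize calls are my addition (I believe the SimDINO reference normalizes inside the head, in which case they are redundant).

import torch
import torch.nn.functional as F

# `views` is a list of ncrops image batches; the first two are the global crops.
global_views = torch.cat(views[:2])  # (2 * B, C, H, W)
all_views = torch.cat(views)         # (ncrops * B, C, H, W)

with torch.no_grad():
    teacher_feat = teacher(global_views)  # (2 * B, D)
student_feat = student(all_views)         # (ncrops * B, D)

# Normalization is my addition; drop it if the head already normalizes.
loss, comp_loss, expa_loss = dino_loss(
    F.normalize(student_feat, dim=-1),
    F.normalize(teacher_feat, dim=-1),
)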

What do you think?

swamidass · May 06 '25 22:05

Hi @swamidass! If the loss is really the only thing to change, it would indeed be relatively quick to implement. Give me the weekend to go through the paper in more detail (I have only had the chance to skim it so far) and compare it to the original DINO/DINOv2 myself. After that I can better estimate whether we would like to add this to lightly. :)

liopeer · May 07 '25 09:05