Include SimDINO
This paper seems to produce a large improvement on DINO v1 and v2 with a simplified approach. Would you consider supporting it?
https://robinwu218.github.io/SimDINO/
Hi @swamidass! That looks like a really promising paper indeed, thanks for making us aware of it. We'll go over it internally and see what we can do about it.
From my read-through, it seems fairly straightforward to implement the DINOv1 variant (SimDINO) by creating a modified version of DINOLoss.
I know there is a DINOv2 implementation nearing completion. It might be straightforward to modify that implementation to create SimDINOv2.
It seems that the following code should be approximately correct for SimDINO (adapted from https://github.com/RobinWu218/SimDINO/blob/main/simdino/main_dino.py):
```python
import torch
import torch.distributed as dist
import torch.distributed.nn as dist_nn
import torch.nn as nn
import torch.nn.functional as F


class MCRLoss(nn.Module):
    def __init__(self, ncrops, reduce_cov=0, expa_type=0, eps=0.5, coeff=1.0):
        super().__init__()
        self.ncrops = ncrops
        self.eps = eps
        self.coeff = coeff
        self.reduce_cov = reduce_cov
        self.expa_type = expa_type

    def forward(self, student_feat, teacher_feat):
        """
        Expansion and compression losses between features of the teacher and student networks.
        """
        student_feat = student_feat.view(self.ncrops, -1, student_feat.shape[-1])
        teacher_feat = teacher_feat.view(2, -1, teacher_feat.shape[-1])
        comp_loss = self.calc_compression(student_feat, teacher_feat)
        if self.expa_type == 0:
            # Only compute the expansion term on the global views.
            expa_loss = self.calc_expansion(student_feat[: len(teacher_feat)])
        elif self.expa_type == 1:
            # Compute the expansion term on the average of student and teacher global views.
            expa_loss = self.calc_expansion(
                (student_feat[: len(teacher_feat)] + teacher_feat) / 2
            )
        loss = -self.coeff * comp_loss - expa_loss
        return loss, comp_loss.detach(), expa_loss.detach()

    def calc_compression(self, student_feat_list, teacher_feat_list):
        """
        Compute the compression loss between student and teacher features.
        """
        # Pairwise cosine similarities between every teacher view and every student view.
        sim = F.cosine_similarity(
            teacher_feat_list.unsqueeze(1), student_feat_list.unsqueeze(0), dim=-1
        )
        # Trick to fill the "diagonal": zero out pairs where teacher and student
        # process the same global view.
        sim.view(-1, sim.shape[-1])[:: len(student_feat_list) + 1, :].fill_(0)
        n_loss_terms = len(teacher_feat_list) * len(student_feat_list) - min(
            len(teacher_feat_list), len(student_feat_list)
        )
        # Average over the batch, then over the remaining (teacher, student) view pairs.
        comp_loss = sim.mean(2).sum() / n_loss_terms
        return comp_loss

    def calc_expansion(self, feat_list) -> torch.Tensor:
        """
        Compute the expansion loss using a coding rate estimate.
        """
        num_views = len(feat_list)
        m, p = feat_list[0].shape
        cov_list = torch.stack([W.T.matmul(W) for W in feat_list])
        N = 1
        if dist.is_initialized():
            N = dist.get_world_size()
            if self.reduce_cov == 1:
                cov_list = dist_nn.all_reduce(cov_list)
        scalar = p / (m * N * self.eps)
        I = torch.eye(p, device=cov_list.device)
        loss = 0.0
        for i in range(num_views):
            # Log-determinant via a Cholesky factorization.
            loss = loss + torch.linalg.cholesky_ex(I + scalar * cov_list[i])[0].diagonal().log().sum()
        loss = loss / num_views
        # The balancing factor gamma; you can also use the commented-out line below.
        # This is ultimately a heuristic, so feel free to experiment.
        loss = loss * (p + N * m) / (p * N * m)
        # loss = loss * (self.eps * N * m) ** 0.5 / p
        return loss


# Use this instead of DINOLoss in lightly.
dino_loss = MCRLoss(
    ncrops=local_crops_number + 2,  # total crops = 2 global crops + local_crops_number
    reduce_cov=0,
    expa_type=1,
    eps=0.5,
    coeff=1.0,
)
```
The only thing that might need modification is the stacking of the outputs; I'm not sure that is handled correctly here.
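Concretely, something like the following (untested) is what I had in mind for the training step. The `views`, `device`, `model.forward`, and `model.forward_teacher` names are just assumptions on my part, loosely following lightly's DINO example; the point is only that the per-crop outputs get concatenated along the batch dimension before the `.view()` calls in `MCRLoss.forward`, with the two global crops first.

```python
# Rough sketch of a training step (untested); the model/forward_teacher names
# loosely follow lightly's DINO example and are assumptions on my part.
views = [view.to(device) for view in views]  # 2 global crops first, then local crops
global_views = views[:2]

teacher_out = [model.forward_teacher(view) for view in global_views]
student_out = [model.forward(view) for view in views]

# MCRLoss expects stacked tensors rather than lists: concatenating along the
# batch dimension gives (ncrops * batch_size, dim), which the .view() calls in
# MCRLoss.forward reshape to (ncrops, batch_size, dim). The first two student
# crops must be the global views so the diagonal-fill trick in calc_compression
# skips the matching teacher/student pairs.
student_feat = torch.cat(student_out, dim=0)
teacher_feat = torch.cat(teacher_out, dim=0)

loss, comp_loss, expa_loss = dino_loss(student_feat, teacher_feat)
```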
What do you think?
Hi @swamidass! If the loss is really the only thing to change, it would indeed be relatively quick to implement. Give me the weekend to go through the paper in more detail (I've only had the chance to skim it so far) and to compare it to the original DINO/DINOv2 myself. After that I can better estimate whether we would like to add this to lightly. :)