
Teacher crops not used?

Open jeromel05 opened this issue 5 months ago • 0 comments

Hello, I noticed that train.py passes the function "dinov2 > data > collate > collate_data_and_cast" to the dataloader to collate the samples in each batch. This function reads the global and local crops from each sample as follows:

def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
    n_global_crops = len(samples_list[0][0]["global_crops"])
    n_local_crops = len(samples_list[0][0]["local_crops"])

However, the dicts in samples_list (a list of lists of dicts) also contain a "global_crops_teacher" key, which holds the two global crops for the teacher. This key is not used in the code: later, in the "dinov2 > train > ssl_meta_arch > SSLMetaArch > forward_backward" function, the same global crops are fed to both the teacher and the student, as seen here:

def forward_backward(self, images, teacher_temp):
    [...]
    global_crops = images["collated_global_crops"].cuda(non_blocking=True)
    [...]
    # teacher output
    def get_teacher_output():
        x, n_global_crops_teacher = global_crops, n_global_crops
        teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True)
    [...]
    student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone(
        [global_crops, local_crops], masks=[masks, None], is_training=True
    )
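
To double-check that the key is really never read during collation, one option is to wrap each sample dict in a small recorder and see which keys get accessed. The sketch below is only an illustration: KeyRecorder, toy_collate and make_sample are hypothetical names, the crops are placeholder strings rather than tensors, and toy_collate only mimics the two key accesses quoted above, not the real collate_data_and_cast (which also handles masking and casting).

```python
class KeyRecorder(dict):
    """Dict that remembers which keys were read via [] (hypothetical helper)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.read_keys = set()

    def __getitem__(self, key):
        self.read_keys.add(key)
        return super().__getitem__(key)


def toy_collate(samples_list):
    # Stand-in for collate_data_and_cast: reads only the two keys quoted above.
    n_global_crops = len(samples_list[0][0]["global_crops"])
    n_local_crops = len(samples_list[0][0]["local_crops"])
    global_crops = [s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]
    local_crops = [s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]
    return {"collated_global_crops": global_crops, "collated_local_crops": local_crops}


def make_sample(idx):
    # Placeholder strings stand in for image tensors (assumption of this sketch).
    crops = KeyRecorder(
        global_crops=[f"g{idx}a", f"g{idx}b"],
        global_crops_teacher=[f"t{idx}a", f"t{idx}b"],
        local_crops=[f"l{idx}_{k}" for k in range(8)],
    )
    return (crops, None)  # (crops dict, target)


samples = [make_sample(i) for i in range(2)]
batch = toy_collate(samples)

# Keys present in the sample but never read by the collate step.
unused = set(samples[0][0]) - samples[0][0].read_keys
print(unused)
```

Applied to the real dataset output and collate function, the same wrapper would show whether "global_crops_teacher" is ever accessed.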

My understanding of the DINO loss was that two different pairs of global crops are created, one fed to the teacher and the other to the student. Here, however, a single pair seems to be passed to both the student and the teacher backbones. Could you please confirm which alternative is correct? Thank you.
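
To make the two alternatives concrete, here is a small sketch that enumerates which (teacher view, student view) pairs would contribute a loss term in each case. It is not taken from the repository: loss_pairs is a hypothetical helper, the view names are placeholders, and the rule of skipping pairs where both networks see the identical view is an assumption made for this illustration.

```python
from itertools import product


def loss_pairs(teacher_views, student_views):
    """Return (teacher, student) view pairs, skipping identical views.

    The skip rule is an assumption of this sketch, not confirmed by the code above.
    """
    return [(t, s) for t, s in product(teacher_views, student_views) if t != s]


local = [f"l{k}" for k in range(2)]  # two local crops to keep the output short

# Alternative A: a single pair of global crops shared by teacher and student,
# which is what the quoted forward_backward code appears to do.
shared = loss_pairs(["g1", "g2"], ["g1", "g2"] + local)

# Alternative B: separate global pairs for teacher and student,
# which was my original understanding.
separate = loss_pairs(["t1", "t2"], ["s1", "s2"] + local)

print(len(shared), shared)
print(len(separate), separate)
```

With a shared pair, each teacher view is matched against the other global view plus the local views; with separate pairs, every teacher view is matched against every student view.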

jeromel05, Jan 17 '24 10:01