D-FINE

Deadlock in reduce_dict when targets are empty on some ranks

Status: Open. PKurnikov opened this issue 8 months ago (0 comments).

Describe the bug

When training a model with contrastive denoising (DN) enabled under DistributedDataParallel (DDP), training can deadlock at reduce_dict(loss_dict) if some ranks receive empty targets (i.e., no ground-truth annotations in the batch).

This happens because:

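In get_contrastive_denoising_training_group(), if max_gt_num == 0 (no ground-truth boxes in any image of the rank's local batch), the function returns None. No DN queries are built, DN-related loss computation is skipped on that rank, and its loss dict therefore lacks all the *_dn_* keys that the other ranks produce.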
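A minimal sketch of that condition, assuming `max_gt_num` is computed as the largest number of ground-truth labels among the images in a rank's local batch, and that each target dict stores its class ids under `'labels'` (as in other DETR-family pipelines). `local_max_gt_num` is a hypothetical name used only for illustration:

```python
from typing import Dict, List, Sequence


def local_max_gt_num(targets: List[Dict[str, Sequence[int]]]) -> int:
    """Largest number of ground-truth labels in any image of this rank's batch.

    Hypothetical helper: mirrors the max_gt_num check described above,
    assuming each target dict keeps its class ids under 'labels'.
    """
    return max((len(t["labels"]) for t in targets), default=0)


# DN is skipped on a rank only if *every* image in its local batch is empty.
empty_batch = [{"labels": []}, {"labels": []}]
mixed_batch = [{"labels": []}, {"labels": [3, 7]}]
print(local_max_gt_num(empty_batch))  # 0 -> DN group not built on this rank
print(local_max_gt_num(mixed_batch))  # 2 -> DN losses computed as usual
```

Under this assumption, with one image per GPU a single empty image is enough to make one rank skip DN while the others do not; larger local batches make the mismatch rarer but not impossible.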

To Reproduce

Modify the COCO dataset preprocessing so that it randomly drops all annotations from a sample. In your data pipeline, inside the ConvertCocoPolysToMask class, apply this patch:

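        # Requires `import random` at the top of the module.
        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
        # Test-only hack: drop all annotations for roughly half of the samples.
        if random.random() < 0.5:
            anno = []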

Expected behavior

Start training. The model hangs at the reduce_dict(loss_dict) call when one rank receives a batch with annotations and computes the DN losses, while another rank receives a batch without any annotations, skips DN computation, and produces a smaller loss dict.

The ranks then disagree on which keys are being reduced, which leads to a deadlock in torch.distributed.all_reduce().
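To see why the key mismatch matters, the sketch below assumes `reduce_dict` follows the common DETR-style pattern of iterating over the sorted keys and stacking the values into one tensor for a single `all_reduce` (an assumption, not verified against this repository). Plain Python lists stand in for tensors, and the per-rank loss values are made up, so it runs without GPUs or `torch.distributed`:

```python
from typing import Dict, List, Tuple


def stacked_payload(loss_dict: Dict[str, float]) -> Tuple[List[str], List[float]]:
    """Keys and values in the order a DETR-style reduce_dict would stack them."""
    keys = sorted(loss_dict)
    return keys, [loss_dict[k] for k in keys]


def missing_keys_per_rank(per_rank: List[Dict[str, float]]) -> Dict[int, List[str]]:
    """For each rank, the loss keys that some other rank has but it lacks."""
    all_keys = set().union(*per_rank)
    return {
        rank: sorted(all_keys - set(losses))
        for rank, losses in enumerate(per_rank)
        if all_keys - set(losses)
    }


# Illustrative values only: rank 0 saw annotations, rank 1 got an empty batch.
rank0 = {"loss_vfl": 0.71, "loss_bbox": 0.42, "loss_vfl_dn_0": 0.55, "loss_bbox_dn_0": 0.30}
rank1 = {"loss_vfl": 0.68, "loss_bbox": 0.40}

print(len(stacked_payload(rank0)[1]), len(stacked_payload(rank1)[1]))  # 4 2
print(missing_keys_per_rank([rank0, rank1]))
# {1: ['loss_bbox_dn_0', 'loss_vfl_dn_0']}
```

Collectives such as `all_reduce` require matching tensor shapes on every rank, so a 4-element payload on one rank and a 2-element payload on another cannot be reduced correctly, which is consistent with the hang reported here.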

To avoid the DDP deadlock caused by missing loss keys when targets are empty, you can patch the forward() method of the DFINECriterion class so that every expected DN-related loss key is present in the losses dictionary before it is returned, using a zero placeholder for any key that was not computed.

Add this block at the end of the forward() method, just before the losses dictionary is returned:

        # Every rank must return the same loss keys, even when its batch had no
        # targets and the DN losses were skipped. The numbered suffixes assume a
        # particular decoder depth; adjust this list to match your config.
        expected_keys = ['loss_bbox_dn_0', 'loss_bbox_dn_1', 'loss_bbox_dn_2', 'loss_bbox_dn_3', 'loss_bbox_dn_pre',
                         'loss_ddf_dn_0', 'loss_ddf_dn_1', 'loss_ddf_dn_2',
                         'loss_fgl_dn_0', 'loss_fgl_dn_1', 'loss_fgl_dn_2', 'loss_fgl_dn_3',
                         'loss_giou_dn_0', 'loss_giou_dn_1', 'loss_giou_dn_2', 'loss_giou_dn_3', 'loss_giou_dn_pre',
                         'loss_vfl_dn_0', 'loss_vfl_dn_1', 'loss_vfl_dn_2', 'loss_vfl_dn_3', 'loss_vfl_dn_pre']

        device = outputs['pred_logits'].device
        for key in expected_keys:
            if key not in losses:
                # Zero placeholder so reduce_dict sees identical keys on all ranks.
                losses[key] = torch.tensor(0.0, device=device)
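Hard-coding the key list appears to tie the patch to one decoder depth. An alternative, sketched here under stated assumptions, is to fill in whatever keys any rank produced. `union_of_keys` and `fill_missing_loss_keys` are hypothetical helpers, not D-FINE functions; the per-rank key lists are passed in explicitly so the logic can be exercised without a process group:

```python
from typing import Callable, Dict, Iterable, List, TypeVar

V = TypeVar("V")


def union_of_keys(per_rank_keys: Iterable[Iterable[str]]) -> List[str]:
    """Sorted union of the loss keys reported by every rank."""
    keys = set()
    for rank_keys in per_rank_keys:
        keys.update(rank_keys)
    return sorted(keys)


def fill_missing_loss_keys(
    losses: Dict[str, V],
    expected_keys: Iterable[str],
    make_zero: Callable[[], V],
) -> Dict[str, V]:
    """Copy of `losses` with a zero placeholder for every absent expected key."""
    filled = dict(losses)
    for key in expected_keys:
        if key not in filled:
            filled[key] = make_zero()
    return filled


# Simulated key lists from two ranks; inside DDP these would come from a collective.
rank_keys = [["loss_vfl", "loss_vfl_dn_0"], ["loss_vfl"]]
expected = union_of_keys(rank_keys)
local = {"loss_vfl": 0.68}
print(fill_missing_loss_keys(local, expected, lambda: 0.0))
# {'loss_vfl': 0.68, 'loss_vfl_dn_0': 0.0}
```

In a real run, each rank could collect `rank_keys` with `torch.distributed.all_gather_object(gathered, sorted(losses))`, where `gathered = [None] * world_size`, and pass `lambda: torch.tensor(0.0, device=device)` as `make_zero`. This adds one object collective per step, and it only works if every rank calls the criterion on every step.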

PKurnikov, Apr 08 '25 14:04