rf-detr icon indicating copy to clipboard operation
rf-detr copied to clipboard

Fix DDP checkpoint loading by using model.module.load_state_dict

Open QuanTran255 opened this issue 2 weeks ago • 1 comments
trafficstars

Description

Summary

This PR fixes a common DistributedDataParallel (DDP) checkpoint loading error in multi-GPU setups by modifying the state_dict loading logic to use model.module.load_state_dict() instead of model.load_state_dict(). This ensures compatibility with checkpoints saved without the "module." prefix (e.g., from single-GPU or non-DDP runs). Additionally, it updates checkpoint saving to always strip the DDP prefix via model.module.state_dict(), making saved files portable across single- and multi-GPU environments. It also adds time.sleep(5) before checkpoint loading to ensure synchronization across distributed processes, preventing race conditions where non-rank-0 processes attempt to load before the file is fully written.

Fixed Issue

Motivation and Context

PyTorch's DDP wraps models with a "module." prefix on parameter keys for multi-GPU synchronization. However, if checkpoints are saved without this prefix (common in RF-DETR's default trainer), loading fails in DDP-wrapped models. This is a frequent pain point in distributed DETR variants (e.g., see PyTorch docs on Saving and Loading Models and community discussions like this Stack Overflow thread). The changes make RF-DETR's checkpoint handling DDP-aware without breaking single-GPU usage.

Dependencies

  • None (relies on existing PyTorch >=1.10 for DDP support; tested with torch 2.0+).

Type of change

Please delete options that are not relevant.

  • [x] Bug fix (non-breaking change which fixes an issue)
  • [ ] New feature (non-breaking change which adds functionality)
  • [ ] This change requires a documentation update

How has this change been tested, please provide a testcase or example of how you tested the change?

Tested on a multi-GPU setup (2x Tesla V100s via torchrun --nproc_per_node=2) with RF-DETR segmentation fine-tuning:

  1. Reproduce Error (Pre-Fix):

    • Train single-GPU to save a checkpoint (e.g., checkpoint_best_total.pth without prefix).
    • Run distributed eval: torchrun --nproc_per_node=2 main.py --run_test --resume checkpoint_best_total.pth.
    • Fails with RuntimeError on key mismatch (missing "module." prefixed keys).
  2. Verify Fix (Post-Merge):

    • Apply changes to main.py (load/save hooks around lines 502 and checkpoint callbacks).
    • Rerun the same distributed eval command—loads successfully, eval proceeds with metrics (e.g., [email protected]=0.75 for custom dataset).
    • Test save portability: Load the new checkpoint into single-GPU (nproc_per_node=1)—no prefix errors.
    • Edge case: Resume interrupted distributed train; barriers ensure sync.

Full test script snippet:

# In main.py
checkpoint = torch.load(path, map_location='cpu')
model.module.load_state_dict(checkpoint['model'])  # Fixed load

Ran on PyTorch 2.1.0, CUDA 12.1; no regressions in non-DDP mode.

Any specific deployment considerations

  • Usability: No API changes—users can drop in fixed checkpoints seamlessly. Recommend adding --master_port flag in docs for cluster runs to avoid port conflicts.
  • Costs/Secrets: None; reduces failed runs on HPC/multi-GPU, potentially saving compute time.
  • Backward Compat: Old checkpoints load fine (via model.module); new saves are prefix-free for broader compatibility.

Docs

  • [ ] Docs updated? What were the changes:

QuanTran255 avatar Nov 04 '25 15:11 QuanTran255

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

:white_check_mark: QuanTran255
:x: elaineryl
You have signed the CLA already but the status is still pending? Let us recheck it.

CLAassistant avatar Nov 04 '25 15:11 CLAassistant