Fix: Prevent DDP Crash by Enabling find_unused_parameters=True in GITMol Training
This PR fixes a DDP runtime error triggered on some machines during multi-GPU training of the GITMol model, caused by conditionally unused parameters in the model's forward pass.
🧠 Root Cause
The following call:
accelerator = Accelerator()
internally uses DistributedDataParallel(find_unused_parameters=False) by default. This means:
- All model parameters must contribute to the loss every iteration.
- If any parameters go unused (e.g., due to conditional logic; see the sketch after the error message), the process crashes with:
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument find_unused_parameters=True to torch.nn.parallel.DistributedDataParallel, and by making sure all forward function outputs participate in calculating loss.
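To make the failure mode concrete, here is a minimal sketch. The model below is a toy stand-in, not taken from the GITMol code: an optional branch that is skipped for a batch leaves its parameters without gradients, which is exactly what DDP's reducer rejects when find_unused_parameters=False.

```python
import torch
import torch.nn as nn

class ToyMultiModal(nn.Module):
    """Toy stand-in (not GITMol) with an optional image branch."""
    def __init__(self):
        super().__init__()
        self.text_proj = nn.Linear(16, 8)
        self.image_proj = nn.Linear(16, 8)   # receives no gradient when `image` is None

    def forward(self, text, image=None):
        out = self.text_proj(text)
        if image is not None:                # conditional path, like a missing modality
            out = out + self.image_proj(image)
        return out.sum()

model = ToyMultiModal()
loss = model(torch.randn(4, 16))             # image omitted: image_proj is never used
loss.backward()
print(model.image_proj.weight.grad)          # None: under DDP with find_unused_parameters=False,
                                             # this missing gradient triggers the RuntimeError above
```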
The GITMol model is a multi-modal architecture that processes molecular graphs, SMILES strings, images, and captions. Due to this design, not all model parameters are guaranteed to participate in the forward pass for every batch (e.g., if an input modality is missing or conditionally ignored). This causes DDP to crash unless configured to tolerate such behavior.
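The change described in the title is applied where the Accelerator is constructed. A minimal sketch, assuming the training script creates the Accelerator directly (variable names are illustrative):

```python
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# Ask Accelerate to wrap the model with
# DistributedDataParallel(find_unused_parameters=True) so that parameters
# skipped by a missing or ignored modality no longer crash the reducer.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```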
📈 Impact
- Fixes the DDP crash on systems where conditional execution leaves some parameters unused.
- Improves robustness across varied hardware and dataset configurations.
- Safe for all training runs: find_unused_parameters=True adds a small per-iteration overhead (an extra traversal of the autograd graph to detect unused parameters) but avoids hard failures.
🧪 Tested
On GB200 and B200 nodes across multiple runs.
Training runs no longer crash under multi-GPU execution.