Wrapping the loss with DistributedDataParallel
Hi,
When using DDP, should the loss_fn also be wrapped in DistributedDataParallel? I'm specifically working with CosFace and ArcFace, both of which have a learnable W parameter inside the loss function. To make sure the gradients of W are synchronized across all processes, is it necessary to wrap loss_fn with DistributedDataParallel?
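For context, here is a rough sketch of the kind of setup I mean (the Linear trunk, the sizes, and the optimizer are just placeholders, and I'm assuming the process group has already been initialized by the launcher):

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from pytorch_metric_learning import losses

# Assumes init_process_group() has already been called by the launcher
rank = dist.get_rank()

trunk = torch.nn.Linear(512, 128).to(rank)   # stand-in for the real embedding model
trunk = DDP(trunk, device_ids=[rank])

loss_fn = losses.ArcFaceLoss(num_classes=1000, embedding_size=128).to(rank)
# Is this wrap needed so that the gradients of the internal W parameter
# get all-reduced across processes, the same way the model's gradients do?
loss_fn = DDP(loss_fn, device_ids=[rank])

# The loss parameters need to be in an optimizer either way
optimizer = torch.optim.SGD(
    list(trunk.parameters()) + list(loss_fn.parameters()), lr=0.01
)

# One training step; data/labels would normally come from a DistributedSampler loader
data = torch.randn(32, 512).to(rank)
labels = torch.randint(0, 1000, (32,)).to(rank)

embeddings = trunk(data)
loss = loss_fn(embeddings, labels)
loss.backward()
optimizer.step()
```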
I saw this discussed here:
https://github.com/KevinMusgrave/pytorch-metric-learning/issues/218