pytorch-metric-learning
Allow efficient=True when using CrossBatchMemory in DistributedLossWrapper
When `efficient=True`:

- All embeddings should be added to each rank's `CrossBatchMemory.embedding_memory`
- Only the current rank's embeddings should be passed as the first argument to `CrossBatchMemory.loss.forward()`
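The two steps above can be sketched as follows. This is a hypothetical illustration of the requested behavior, not the library's implementation: plain Python lists stand in for embedding tensors, `all_gather` stands in for `torch.distributed.all_gather`, and `FakeCrossBatchMemory` is an invented stand-in for `CrossBatchMemory`.

```python
def all_gather(local_embeddings, world):
    # Stand-in for torch.distributed.all_gather: every rank
    # receives the concatenation of all ranks' embeddings.
    return [e for rank_batch in world for e in rank_batch]

class FakeCrossBatchMemory:
    # Invented stand-in for CrossBatchMemory, reduced to the two
    # pieces relevant here: the memory queue and the loss call.
    def __init__(self):
        self.embedding_memory = []

    def add_to_memory(self, embeddings):
        self.embedding_memory.extend(embeddings)

    def loss_forward(self, query_embeddings):
        # Stand-in for CrossBatchMemory.loss.forward(): the query
        # embeddings (first argument) are compared against the
        # full memory. Here we just return both for inspection.
        return query_embeddings, list(self.embedding_memory)

def efficient_step(rank, world, memory):
    # Step 1: embeddings from ALL ranks enter this rank's memory.
    memory.add_to_memory(all_gather(world[rank], world))
    # Step 2: only THIS rank's embeddings are passed as the
    # first argument to the underlying loss.
    return memory.loss_forward(world[rank])
```

With this split, each rank computes loss terms only for its own batch (avoiding duplicated work across ranks), while every rank's memory still sees all embeddings.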