pytorch-metric-learning
How to use distributed CrossBatchMemory in a MoCo way?
Hello, your library really eases my burden, I really appreciate it. Thanks! But when I switch my code to DDP, something goes wrong. Here is my definition of the loss:
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist
loss_fn = losses.MultiSimilarityLoss()
miner = miners.MultiSimilarityMiner()
loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=128, memory_size=8000, miner=miner)
loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn)
Then I try to calculate the final loss the same way as before:
label_one = torch.arange(cur_start_id+bs*rank, cur_start_id+bs*(rank+1)).cuda(local_rank, non_blocking=True)
labels = torch.tile(label_one, dims=(2,))
enqueue_idx = torch.arange(bs, bs*2)
loss = loss_fn(torch.cat([q,k2],dim=0), labels, enqueue_idx=enqueue_idx)
This raises:
TypeError: forward() got an unexpected keyword argument 'enqueue_idx'
Am I using it wrong? Or is it just not supported yet? Is there a demo of using CrossBatchMemory with DDP in a MoCo way? Looking forward to your reply!
Ah crap! Maybe I need to add **kwargs to the forward function of the distributed wrapper https://github.com/KevinMusgrave/pytorch-metric-learning/blob/63e4ecb781c5735ad714f61a3eecc55f72496905/src/pytorch_metric_learning/utils/distributed.py#L72-L86
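Roughly, the idea is to let the wrapper's forward accept extra keyword arguments and pass them through to the wrapped loss. Below is a minimal sketch of that pattern, not the actual library code (the real DistributedLossWrapper also all-gathers embeddings and labels across processes), and it assumes a pre-v2.0 CrossBatchMemory whose forward takes enqueue_idx:

import torch
from pytorch_metric_learning import losses

class KwargsPassthroughWrapper(torch.nn.Module):
    # Illustrative wrapper only: the point is that **kwargs (e.g. enqueue_idx)
    # are forwarded unchanged to the wrapped loss.
    def __init__(self, loss):
        super().__init__()
        self.loss = loss

    def forward(self, embeddings, labels, indices_tuple=None, **kwargs):
        # The real distributed wrapper would gather embeddings and labels from
        # all processes here before calling the wrapped loss.
        return self.loss(embeddings, labels, indices_tuple, **kwargs)

inner = losses.CrossBatchMemory(losses.MultiSimilarityLoss(), embedding_size=128, memory_size=8000)
wrapped = KwargsPassthroughWrapper(inner)
embeddings = torch.randn(8, 128)
labels = torch.arange(4).repeat(2)   # query i and key i share label i
enqueue_idx = torch.arange(4, 8)     # enqueue only the keys (second half)
loss = wrapped(embeddings, labels, enqueue_idx=enqueue_idx)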
Hi, I have another question. When using CrossBatchMemory with a batch of queries and keys, before the keys are put into the memory bank, if some embeddings already in the memory bank have the same labels as the queries, will they be considered positive samples?
Yes, they will be considered positive samples. So if you're using it for MoCo, you have to keep incrementing your labels so that no positive pairs are formed between the current batch and the embeddings in memory.
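For example, here is a minimal sketch of that label bookkeeping, mirroring the cur_start_id logic from the snippet above (bs, world_size, and rank values are illustrative):

import torch

bs = 32           # per-process batch size
world_size = 4    # number of DDP processes
rank = 0          # this process's rank
cur_start_id = 0  # persists across iterations, never reset

for step in range(3):
    # every sample in every process gets a label that has never been used before
    label_one = torch.arange(cur_start_id + bs * rank, cur_start_id + bs * (rank + 1))
    labels = torch.tile(label_one, dims=(2,))  # query and its key share a label
    cur_start_id += bs * world_size            # advance past all processes' labels
    # ... compute q, k2 and call loss_fn(torch.cat([q, k2], dim=0), labels, ...)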
This is fixed in v1.7.3. The only caveat is that len(enqueue_idx) must be the same in all processes. In other words, the same number of embeddings have to be enqueued in each process. This will probably be fixed in v2.0, if I change enqueue_idx to enqueue_mask: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/575
enqueue_idx is now enqueue_mask starting in v2.0.0. So a different number of embeddings can be enqueued by each process.
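For example, a sketch of the v2.0.0+ call, reusing q, k2, bs, and labels from the snippet in the question and assuming loss_fn is built the same way as above:

enqueue_mask = torch.zeros(2 * bs, dtype=torch.bool)
enqueue_mask[bs:] = True  # enqueue only the keys (second half), not the queries
loss = loss_fn(torch.cat([q, k2], dim=0), labels, enqueue_mask=enqueue_mask)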