How to use distributed CrossBatchMemory in a MoCo way?


Hello sir, your library really eases my burden, I really appreciate it. Thanks! But when I change my code to a DDP setup, something goes wrong. Here is my definition of the loss:

from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

loss_fn = losses.MultiSimilarityLoss()
miner = miners.MultiSimilarityMiner()
loss_fn = losses.CrossBatchMemory(loss_fn, embedding_size=128, memory_size=8000, miner=miner)
loss_fn = pml_dist.DistributedLossWrapper(loss=loss_fn)

Then I try to calculate the final loss the same way as before:

# one unique label per sample in this process, offset by rank
label_one = torch.arange(cur_start_id + bs * rank, cur_start_id + bs * (rank + 1)).cuda(local_rank, non_blocking=True)
labels = torch.tile(label_one, dims=(2,))  # queries and keys share labels
enqueue_idx = torch.arange(bs, bs * 2)  # only enqueue the keys (second half)
loss = loss_fn(torch.cat([q, k2], dim=0), labels, enqueue_idx=enqueue_idx)

It throws:

TypeError: forward() got an unexpected keyword argument 'enqueue_idx'

Am I using it wrong? Or is it not yet well supported? Is there any demo of DDP CrossBatchMemory used in a MoCo way? Looking forward to your reply!

ZihaoH · May 17 '22

Ah crap! Maybe I need to add **kwargs to the forward function of the distributed wrapper https://github.com/KevinMusgrave/pytorch-metric-learning/blob/63e4ecb781c5735ad714f61a3eecc55f72496905/src/pytorch_metric_learning/utils/distributed.py#L72-L86
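For reference, here is a minimal sketch of the kind of change being proposed; this is hypothetical, not the actual source, and the all-gather of embeddings/labels across ranks that the real wrapper performs is omitted:

import torch

# Hypothetical sketch: forward extra keyword arguments through the wrapper so
# that CrossBatchMemory-specific arguments (e.g. enqueue_idx) reach the inner
# loss instead of raising a TypeError.
class DistributedLossWrapperSketch(torch.nn.Module):
    def __init__(self, loss):
        super().__init__()
        self.loss = loss

    def forward(self, embeddings, labels, indices_tuple=None, **kwargs):
        # (all-gather of embeddings and labels across processes omitted)
        return self.loss(embeddings, labels, indices_tuple, **kwargs)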

KevinMusgrave · May 17 '22

Hi, I have another question. When using CrossBatchMemory with a batch of queries and keys, before the keys are put into the memory bank, if some embeddings already in the memory bank have the same labels as the queries, will they be considered positive samples?

ZihaoH · May 27 '22

Yes, they will be considered positive samples. So if you're using it for MoCo, you have to keep incrementing your labels so that no positive pairs are formed between the current batch and the embeddings in memory.
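For example, here is a minimal sketch of that label bookkeeping, reusing the variable names from the snippet above (bs, rank, world_size, loss_fn, and the loader are assumptions):

import torch

# Minimal sketch: give every sample a globally unique label so the only
# positive pair for a query is its own key, never an older memory entry.
cur_start_id = 0
for q, k2 in loader:  # assumed loader yielding two augmented views per batch
    label_one = torch.arange(cur_start_id + bs * rank,
                             cur_start_id + bs * (rank + 1)).cuda()
    labels = torch.tile(label_one, dims=(2,))  # query i and key i share a label
    loss = loss_fn(torch.cat([q, k2], dim=0), labels,
                   enqueue_idx=torch.arange(bs, bs * 2))
    cur_start_id += bs * world_size  # never reuse a label across iterations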

KevinMusgrave · May 27 '22

This is fixed in v1.7.3. The only caveat is that len(enqueue_idx) must be the same in all processes. In other words, the same number of embeddings has to be enqueued in each process. This will probably be fixed in v2.0, if I change enqueue_idx to enqueue_mask: https://github.com/KevinMusgrave/pytorch-metric-learning/issues/575

KevinMusgrave · Jan 29 '23

enqueue_idx is now enqueue_mask starting in v2.0.0, so a different number of embeddings can be enqueued by each process.
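For example, a minimal sketch of MoCo-style enqueuing with the v2.0.0 API, where q, k2, bs, labels, and loss_fn follow the earlier snippets:

import torch

# Minimal sketch for v2.0.0+: enqueue only the keys (the second half of the
# batch) by marking them True in a boolean mask.
embeddings = torch.cat([q, k2], dim=0)  # queries first, then keys
enqueue_mask = torch.zeros(len(embeddings), dtype=torch.bool)
enqueue_mask[bs:] = True  # True = this embedding goes into the memory bank
loss = loss_fn(embeddings, labels, enqueue_mask=enqueue_mask)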

KevinMusgrave · Jan 30 '23