pytorch-metric-learning
Bug for distributed wrapper regarding to cross batch memory loss
First of all, I really appreciate this repo. Thank you very much for the contribution! However, there are 2 functions in distributed.py (in the loss and miner wrappers) that will not work correctly: gather_emb_and_ref and gather_enqueue_mask.
Let's take gather_enqueue_mask for example:
```python
def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    return torch.cat([enqueue_mask, all_gather(enqueue_mask)], dim=0)


def all_gather(x):
    world_size = torch.distributed.get_world_size()
    if world_size > 1:
        rank = torch.distributed.get_rank()
        x_list = [torch.ones_like(x) for _ in range(world_size)]
        torch.distributed.all_gather(x_list, x.contiguous())
        # remove curr rank
        x_list.pop(rank)
        return torch.cat(x_list, dim=0)
    return None
```
The all_gather function pops the current rank, which is a different index on each GPU, and gather_enqueue_mask then concatenates the local enqueue_mask in front of the gathered ones. As a result, the order of the concatenated masks is not guaranteed to be the same across ranks. When using cross batch memory losses, the embedding_memory therefore ends up different on different GPUs, which I have already confirmed by running some test functions.
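To make the ordering problem concrete, here is a small single-process sketch (ranks are simulated, no process group is created; the helper names are just for illustration): the pop-then-prepend pattern produces a different concatenation order on each rank, while a plain rank-ordered concatenation does not.

```python
import torch

# Simulate 3 ranks, each producing its own block of embeddings.
world_size = 3
per_rank = [torch.full((2, 4), float(r)) for r in range(world_size)]

def pop_then_prepend(rank):
    # Mimics the current behavior: all_gather, pop the current rank,
    # then concatenate the local tensor in front of the rest.
    gathered = [t.clone() for t in per_rank]
    gathered.pop(rank)
    return torch.cat([per_rank[rank]] + gathered, dim=0)

def rank_ordered(rank):
    # Mimics the proposed fix: keep every gathered tensor in rank order
    # and overwrite the local slot with the local tensor.
    gathered = [t.clone() for t in per_rank]
    gathered[rank] = per_rank[rank]
    return torch.cat(gathered, dim=0)

print(torch.equal(pop_then_prepend(0), pop_then_prepend(1)))  # False: ranks disagree
print(torch.equal(rank_ordered(0), rank_ordered(1)))          # True: ranks agree
```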
Here I propose 2 changes to fix this issue:
```python
def gather(emb, labels):
    device = emb.device
    if labels is not None:
        labels = c_f.to_device(labels, device=device)
    # Gather the embeddings from every replica.
    emb = c_f.to_device(emb, device=device)
    emb_list = [torch.ones_like(emb) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(emb_list, emb)
    # Gathered tensors have no gradient, so we overwrite the gathered tensor for the
    # current replica with the embeddings produced on this replica, which do have gradients.
    emb_list[torch.distributed.get_rank()] = emb
    all_emb = torch.cat(emb_list, dim=0)
    # Gather the labels from every replica.
    if labels is not None:
        labels_list = [torch.ones_like(labels) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(labels_list, labels)
        # Keep the gathered labels in rank order; overwrite the current replica's
        # slot with the local labels.
        labels_list[torch.distributed.get_rank()] = labels
        all_labels = torch.cat(labels_list, dim=0)
    else:
        all_labels = None
    return all_emb, all_labels, labels
```
and
```python
def gather_enqueue_mask(enqueue_mask, device):
    if enqueue_mask is None:
        return enqueue_mask
    enqueue_mask = c_f.to_device(enqueue_mask, device=device)
    # Gather the enqueue_mask from every replica.
    enqueue_mask_list = [torch.ones_like(enqueue_mask) for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(enqueue_mask_list, enqueue_mask)
    # Keep the gathered masks in rank order; overwrite the current replica's
    # slot with the local enqueue_mask.
    enqueue_mask_list[torch.distributed.get_rank()] = enqueue_mask
    return torch.cat(enqueue_mask_list, dim=0)
```
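As a rough sanity check for this rank-ordered gather, something along these lines can be run on CPU with the gloo backend (the worker function and its assertion are just my own sketch, not part of the repo's test suite):

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Each rank contributes a distinct embedding block.
    emb = torch.full((2, 4), float(rank), requires_grad=True)

    # Rank-ordered all_gather; gathered copies carry no gradient,
    # so overwrite the local slot with the grad-carrying local tensor.
    emb_list = [torch.zeros_like(emb) for _ in range(world_size)]
    dist.all_gather(emb_list, emb.detach())
    emb_list[rank] = emb
    all_emb = torch.cat(emb_list, dim=0)

    # Every rank should now see the exact same concatenation, so anything
    # enqueued into the cross batch memory stays in sync across GPUs.
    reference = torch.cat([torch.full((2, 4), float(r)) for r in range(world_size)], dim=0)
    assert torch.equal(all_emb.detach(), reference), f"rank {rank} saw a different order"

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```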
Thanks for the code and explanation @zhaoyuac09! I've found the distributed stuff to be quite tricky.
I'm really busy for the next few days, so I'll have to look at your code a bit later.
In the meantime, if you'd like, you can open a pull request with your code changes.
Thank you @KevinMusgrave. I would be happy to create a pull request once I finish more test cases here. If all of them pass, I will wrap up the changes and open a pull request.
Another issue: when the cross batch memory loss is wrapped with the distributed wrapper, the miner should not be wrapped as well, since the miner already has access to all embeddings, labels, etc. after the all-gather in the loss wrapper (https://github.com/KevinMusgrave/pytorch-metric-learning/blob/c38c07c0587bad7c463ae98293d2978e931f0ae6/src/pytorch_metric_learning/utils/distributed.py#L155). So the usage I have in mind is to pass the miner into CrossBatchMemory itself and only wrap the outer loss, roughly as sketched below.
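This is only a sketch of how I'm calling it, not an official recommendation; the embedding size and memory size are just placeholders.

```python
from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

# The miner goes inside CrossBatchMemory; the loss wrapper's all_gather already
# hands it the embeddings and labels from every rank, so wrapping the miner in
# DistributedMinerWrapper as well would gather twice.
inner_loss = losses.ContrastiveLoss()
miner = miners.MultiSimilarityMiner()
loss_fn = losses.CrossBatchMemory(
    loss=inner_loss, embedding_size=128, memory_size=1024, miner=miner
)
loss_fn = pml_dist.DistributedLossWrapper(loss_fn)

# In the training loop, each rank simply calls:
# loss = loss_fn(embeddings, labels)
```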
I believe your repo is really nice and almost there for full distributed training support. Thanks for the nice repo, and let's make it even better.
I am facing the same issue. @KevinMusgrave, have you reviewed @zhaoyuac09's PR?
@lolongcovas It's not passing the existing test. See my comment: https://github.com/KevinMusgrave/pytorch-metric-learning/pull/642#issuecomment-1849537159
Here's the test file: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/tests/utils/test_distributed.py