SupContrast
DistributedDataParallel
Is there a way to make SupContrast work with DistributedDataParallel? By default each worker can only see its own sub-batch, so the inter-sub-batch relationships between samples will not be utilized.
You can use `all_gather` to gather the features together. The caveat is that you need to manually propagate gradients through the `all_gather` op, as it doesn't backpropagate automatically.
I finally made it work with the help of diffdist, which provides a differentiable `all_gather` wrapper.
Hi, can you share your code for how to implement this? I am not familiar with `all_gather` and similar operations. Thanks a lot.
First, install diffdist. Then put the following snippet before calling the criterion:

```python
import torch
import torch.distributed
import diffdist.functional as distops

world_size = torch.distributed.get_world_size()

# Gather features from every worker. diffdist's all_gather is
# differentiable, so gradients flow back through the gather.
features = distops.all_gather(
    gather_list=[torch.zeros_like(features) for _ in range(world_size)],
    tensor=features,
    next_backprop=None,
    inplace=True,
)
features = torch.cat(features)

# Labels carry no gradient, but gather them the same way so they
# line up with the gathered features.
labels = distops.all_gather(
    gather_list=[torch.zeros_like(labels) for _ in range(world_size)],
    tensor=labels,
    next_backprop=None,
    inplace=True,
)
labels = torch.cat(labels)
```
Thank you for your quick reply.
So then I can simply compute the loss as usual and backpropagate the gradient?
Yes, but I'm not sure this is bug-free.
OK. Anyway, thanks a lot.
Just for reference, this seems to be a reliable solution.