SimCLR

The loss doesn't decrease when using multiple nodes.

Open lailvlong opened this issue 3 years ago • 10 comments

When I use one node, the code runs well. However, when I use 2 nodes and set the batch_size to 64, the loss stays around 5.545 and doesn't decrease. Since 5.545 is roughly ln(256), i.e., chance level for this configuration, it seems like the network never learns anything during training. I have checked that the parameters are not frozen. I think there may be something wrong with the GatherLayer, but I cannot find the bug. Have you run into this problem?
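
For reference, a quick sanity check (assuming 1 GPU per node, so world_size = 2): with uninformative embeddings the NT-Xent softmax is roughly uniform over all other gathered samples, so the loss plateaus near the log of their count.

import math

batch_size = 64                      # per-GPU batch size
world_size = 2                       # 2 nodes x 1 GPU each (assumed)
n = 2 * batch_size * world_size      # 256 augmented views in the gathered batch
print(math.log(n - 1))               # ~5.54: a loss stuck here means the model is at chance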

lailvlong avatar Mar 18 '21 09:03 lailvlong

Same issue here: the loss is stuck at around 6.23. Everything works fine when training on a single node.

Attila94 avatar Mar 26 '21 11:03 Attila94

I have the same multi-node loss issue. Any solution for this problem?

Sdhir avatar Apr 01 '21 17:04 Sdhir

Hi guys, do you only see the problem with multiple nodes? I get the same issue even on a single node with multiple processes (ranks). Any suggestions?

Pexure avatar May 22 '21 12:05 Pexure

Same issue here. Could you please tell me whether you have figured out any solution?

MAGI003769 avatar Aug 04 '21 03:08 MAGI003769

There may be a bug in class NT_Xent(nn.Module) when using multiple GPUs. I think the mask and the positive/negative pairs are wrong.

huangdi95 avatar Sep 02 '21 03:09 huangdi95

I agree. To make the implementation work with multiple nodes or multiple processes, I think the GatherLayer should be applied to z_i and z_j independently before the concatenation (z = torch.cat((z_i, z_j), dim=0)). In my case, this slight modification solved the problem and the loss started to decrease during training.
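
A self-contained sketch of that change (the helper name gather_views is just for illustration, GatherLayer below is the usual autograd-aware all_gather, and torch.distributed is assumed to be initialized):

import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    """all_gather that keeps gradients flowing back to every rank."""

    @staticmethod
    def forward(ctx, x):
        out = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(out, x)
        return tuple(out)

    @staticmethod
    def backward(ctx, *grads):
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads)              # sum gradients from every rank
        return all_grads[dist.get_rank()]

def gather_views(z_i, z_j, world_size):
    # Gather each view separately, then concatenate, so the result is ordered
    # [z_i from all ranks, z_j from all ranks] -- the ordering that the NT-Xent
    # mask and positive-pair indexing assume.
    if world_size > 1:
        z_i = torch.cat(GatherLayer.apply(z_i), dim=0)
        z_j = torch.cat(GatherLayer.apply(z_j), dim=0)
    return torch.cat((z_i, z_j), dim=0)

Gathering z = torch.cat((z_i, z_j)) as a whole instead interleaves the ranks and silently misaligns the positives, which is consistent with the stuck-at-chance loss reported above.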

dltkddn0525 avatar Nov 03 '21 12:11 dltkddn0525

Hey guys, I have adjusted some code in the forward function of class NT_Xent and now it works, but I found that the multi-GPU performance is much worse than using only one GPU. Do you know the reason?

def forward(self, z_i, z_j):
    N = 2 * self.batch_size * self.world_size
    if self.world_size > 1:
        # Gather z_i and z_j from every rank separately so that the final
        # ordering is [z_i from all ranks, z_j from all ranks], which is what
        # the mask and positive-pair indexing below expect.
        z_list_i = [torch.zeros_like(z_i) for _ in range(dist.get_world_size())]
        z_list_j = [torch.zeros_like(z_j) for _ in range(dist.get_world_size())]
        z_list_i = diffdist.functional.all_gather(z_list_i, z_i)  # differentiable all_gather
        z_list_j = diffdist.functional.all_gather(z_list_j, z_j)
        z_i = torch.cat(z_list_i, dim=0)
        z_j = torch.cat(z_list_j, dim=0)
    z = torch.cat((z_i, z_j), dim=0)
    # ... rest of the function (similarity, mask, cross-entropy) unchanged

wooozihui avatar Dec 06 '21 09:12 wooozihui

OK, I think this question has been solved: by mistake, the DDP-wrapped model did not replace the original one, so training did not work well. After setting the training model properly, this function works fine for multi-GPU training with DDP.
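
For anyone hitting the same thing, a minimal sketch of the pitfall (the toy model and LOCAL_RANK handling are placeholders; this assumes torch.distributed is already initialized by a torchrun-style launcher):

import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ.get("LOCAL_RANK", 0))
model = nn.Linear(512, 128).cuda(local_rank)      # toy stand-in for the SimCLR network
ddp_model = DDP(model, device_ids=[local_rank])

x = torch.randn(64, 512).cuda(local_rank)

# Pitfall: building the DDP wrapper but still calling the original module;
# gradients are then never synchronized, so each replica drifts on its own.
# out = model(x)

# Fix: run the forward pass through the DDP wrapper so backward() all-reduces
# the gradients across ranks.
out = ddp_model(x)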

wooozihui avatar Dec 06 '21 14:12 wooozihui

@wooozihui, can you elaborate on where/how you replaced or properly set the training model?

jwjohnson314 avatar Oct 13 '22 15:10 jwjohnson314

Hello guys, I've been plagued by the inexplicable code in nt_xent.py for a long time too. Finally I found this issue.

I agree with @dltkddn0525's opinion. The all_gather result of z used to build the mask should be ordered like [z_i1, z_i2, ..., z_iw, z_j1, z_j2, ..., z_jw] instead of [z_i1, z_j1, z_i2, z_j2, ..., z_iw, z_jw], where w is the world_size.
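
For reference, a sketch of a mask consistent with that ordering (meant to mirror what mask_correlated_samples in nt_xent.py should produce; treat it as illustrative):

import torch

def mask_correlated_samples(batch_size: int, world_size: int) -> torch.Tensor:
    # Boolean mask keeping only the negatives, assuming the gathered batch is
    # ordered [z_i from all ranks, z_j from all ranks].
    half = batch_size * world_size
    N = 2 * half
    mask = torch.ones((N, N), dtype=torch.bool)
    mask.fill_diagonal_(False)          # drop self-similarities
    for k in range(half):
        mask[k, half + k] = False       # drop the (i, j) positive pair
        mask[half + k, k] = False       # drop the (j, i) positive pair
    return mask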

I've made a pull request for this issue; hope it helps!

lxysl avatar Sep 02 '23 14:09 lxysl