dinov2
Sinkhorn-Knopp produces NaN
I was running a debug job on 2 processes and noticed that Sinkhorn-Knopp randomly produces NaN values for the teacher (the previous loss and the data in the current batch all look normal). On further inspection, I noticed that the all_reduce call in sinkhorn_knopp_teacher doesn't seem to wait for the other processes, though I'm not sure whether that is the cause.
Here's the code (note: `os.environ["RANK"]` is a string, so it has to be converted with `int(...)` before comparing it to 0; otherwise both ranks load `tmp.pth`):

```python
import os

import torch
import torch.distributed as dist


def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
    rank = int(os.environ["RANK"])  # env vars are strings; compare as int
    print("teacher sinkhorn on RANK", rank)
    # debug: replace the inputs with tensors saved from runs that produced NaN
    if rank == 0:
        teacher_output = torch.load("tmp0.pth")
    else:
        teacher_output = torch.load("tmp.pth")
    teacher_temp = torch.load("temperature.pth")
    teacher_output = teacher_output.float()
    print("starting sinkhorn iteration on RANK", rank)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    print("world size on RANK", rank, world_size)
    Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K-by-B for consistency with notations from our paper
    B = Q.shape[1] * world_size  # number of samples to assign
    K = Q.shape[0]  # how many prototypes

    # make the matrix sum to 1
    sum_Q = torch.sum(Q)
    if dist.is_initialized():
        print("reducing sum_Q on RANK", rank)
        dist.all_reduce(sum_Q)
    Q /= sum_Q

    for it in range(n_iterations):
        # normalize each row: total weight per prototype must be 1/K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        if dist.is_initialized():
            dist.all_reduce(sum_of_rows)
        Q /= sum_of_rows
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= torch.sum(Q, dim=0, keepdim=True)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    print("finished sinkhorn iteration on RANK", rank)
    return Q.t()
```
The loaded .pth files are tensors saved from earlier runs that produced NaN. (As a check, I loaded both tensors in a single process, ran Sinkhorn-Knopp there, and simulated all_reduce by manually adding the per-rank sum_Q and sum_of_rows values before normalizing; no NaN values were produced.)
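The single-process check described above can be sketched as follows. This is a minimal illustration, not dinov2 code: `sinkhorn_sharded` is a hypothetical helper, NumPy stands in for torch, and each "rank" is assumed to hold a `(B_local, K)` array of teacher logits. Every `dist.all_reduce(SUM)` is replaced by summing the per-shard partial results, while the column normalization stays local because each column belongs to exactly one rank.

```python
import numpy as np


def sinkhorn_sharded(shards, temp, n_iterations=3):
    """Hypothetical helper: DINOv2-style Sinkhorn-Knopp over simulated ranks.

    `shards` is a list of (B_local, K) logit arrays, one per simulated rank.
    Summing partial sums over the shards plays the role of all_reduce(SUM).
    """
    # Q_r is K-by-B_local for every simulated rank r
    qs = [np.exp(s.astype(np.float32) / temp).T for s in shards]
    # total samples; equals Q.shape[1] * world_size when shards are equal-sized
    B = sum(q.shape[1] for q in qs)
    K = qs[0].shape[0]  # number of prototypes

    # make the whole (distributed) matrix sum to 1: stands in for all_reduce(sum_Q)
    sum_Q = sum(q.sum() for q in qs)
    qs = [q / sum_Q for q in qs]

    for _ in range(n_iterations):
        # rows span all ranks: stands in for all_reduce(sum_of_rows)
        sum_of_rows = sum(q.sum(axis=1, keepdims=True) for q in qs)
        qs = [q / sum_of_rows / K for q in qs]
        # columns are local to one rank: no communication needed
        qs = [q / q.sum(axis=0, keepdims=True) / B for q in qs]

    # columns sum to 1 after scaling by B; transpose back to (B_local, K)
    return [(q * B).T for q in qs]


rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 5)).astype(np.float32)  # random stand-in data
full = sinkhorn_sharded([logits], temp=0.1)[0]  # like world_size == 1
split = np.concatenate(sinkhorn_sharded([logits[:4], logits[4:]], temp=0.1))
print(np.abs(full - split).max())  # should be at float32 rounding level
```

If the sharded and unsharded results agree on the saved tensors but the 2-process run still yields NaN, that points at what each rank actually loads or communicates rather than at the normalization arithmetic.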
Here's the printed output:

```
forward on RANK 1
forward on RANK 0
teacher sinkhorn on RANK 1
teacher sinkhorn on RANK 0
starting sinkhorn iteration on RANK 1
world size on RANK 1 2
reducing sum_Q on RANK 1
finished sinkhorn iteration on RANK 1
starting sinkhorn iteration on RANK 0
world size on RANK 0 2
reducing sum_Q on RANK 0
```
I'm running this on a single node with 2 L4 GPUs (the same issue arises on 4 nodes with 8 L4 GPUs each).
Is there a missing call to a barrier?
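One detail that may matter when reading the log above: per the torch.distributed documentation, for CUDA tensors the default (synchronous) `dist.all_reduce` can return before the GPU has actually finished the reduction; later operations on the same CUDA stream still see the reduced values. So a rank printing "finished" before its peer reaches all_reduce does not by itself show that a barrier is missing. To make debug prints reflect real completion, a helper along these lines could be used. `all_reduce_and_wait` is a hypothetical name, not a dinov2 or torch.distributed API, and the sketch assumes a process group is already initialized.

```python
import torch
import torch.distributed as dist


def all_reduce_and_wait(t: torch.Tensor) -> torch.Tensor:
    """Sum `t` in place across ranks and block until the result is ready.

    Hypothetical debugging helper; not part of dinov2 or torch.distributed.
    """
    # async_op=True returns a Work handle; wait() blocks the host for CPU
    # collectives and the current CUDA stream for CUDA collectives.
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()
    if t.is_cuda:
        # Block the host until the GPU work is done, so that a print()
        # placed after this call really means "reduction finished".
        torch.cuda.synchronize(t.device)
    return t


# Debug-only usage inside sinkhorn_knopp_teacher, e.g.:
#     all_reduce_and_wait(sum_Q)
#     print("reduced sum_Q on RANK", rank, sum_Q.item())
```

This should only change when the host observes completion, not the numerical result, since operations queued later on the same stream already wait for the collective.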