kmeans_pytorch

Trying to solve OOM errors on large-scale datasets

Open · AliscaChen opened this issue 4 years ago · 0 comments

Hi, this is an amazing module. However, when I set a large number of clusters or use a very large dataset, I run into out-of-memory (OOM) errors. I have refactored the distance calculation so that it processes the data in batches. Please feel free to use it if it works for you.
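A rough estimate of why this happens (a sketch, not part of the original post): the one-shot broadcast `A - B` in `pairwise_distance` materializes an N×K×M tensor, where N is the number of points, K the number of clusters and M the number of features. The helper names `intermediate_bytes` and `max_batch_size`, the float32 assumption (4 bytes per element) and the sizes below are all hypothetical.

```python
def intermediate_bytes(n_rows, n_centers, n_features, bytes_per_elem=4):
    """Size in bytes of an (n_rows x n_centers x n_features) difference tensor."""
    return n_rows * n_centers * n_features * bytes_per_elem


def max_batch_size(budget_bytes, n_centers, n_features, bytes_per_elem=4):
    """Largest batch whose difference tensor fits in budget_bytes (at least 1 row)."""
    return max(1, budget_bytes // (n_centers * n_features * bytes_per_elem))


# Hypothetical problem size: 1M points, 1000 clusters, 128 features (float32).
N, K, M = 1_000_000, 1_000, 128

full = intermediate_bytes(N, K, M)           # one-shot broadcast over all points
batched = intermediate_bytes(100_000, K, M)  # one batch with batch_size=100000
print(f"full: {full / 1e9:.1f} GB, batched: {batched / 1e9:.1f} GB")
# prints "full: 512.0 GB, batched: 51.2 GB"
print("batch size for a 2 GB budget:", max_batch_size(2 * 10**9, K, M))
# prints "batch size for a 2 GB budget: 3906"
```

Two caveats: `(A - B) ** 2.0` allocates a second tensor of the same size before the first is freed, so the real peak is roughly double; and even with batching, a default of `batch_size=100000` can still exceed GPU memory when K and M are large, so it is worth sizing the batch against the memory actually available.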

Best regards,
Alisca

Appendix: the refactored code for the Euclidean distance calculation, computed in batches.

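```python
import torch


def pairwise_distance(data1, data2, device=torch.device('cpu'), batch_size=100000):
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)

    # 1*K*M (K = number of rows in data2, e.g. the cluster centers)
    B = data2.unsqueeze(dim=0)

    # preallocate the N*K result on the same device and dtype as the inputs
    dis_reduce = torch.zeros(data1.shape[0], data2.shape[0], dtype=data1.dtype, device=device)
    # ceil(N / batch_size) batches, using integer arithmetic instead of np.ceil
    num_batches = (data1.shape[0] + batch_size - 1) // batch_size
    for batch_idx in range(num_batches):
        start, end = batch_idx * batch_size, (batch_idx + 1) * batch_size
        # only a batch_size*K*M intermediate tensor exists at any one time
        dis = (A[start:end] - B) ** 2.0
        # sum over features; no squeeze(), so shapes stay correct when a batch
        # or data2 has a single row
        dis_reduce[start:end] = dis.sum(dim=-1)
    return dis_reduce
```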
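A possible further reduction, sketched here as an alternative rather than as part of the original proposal: expanding the squared distance as ‖a‖² − 2a·b + ‖b‖² replaces the per-batch (batch × K × M) difference tensor with a (batch × K) matrix product, so the feature dimension is never broadcast. The function name `pairwise_distance_mm` and the test sizes are hypothetical. The expansion can lose a little precision to floating-point cancellation, which is why small negative values are clamped to zero.

```python
import torch


def pairwise_distance_mm(data1, data2, batch_size=100000):
    """Squared Euclidean distances via ||a||^2 - 2 a.b + ||b||^2, in row batches.

    Extra memory per batch is a few (batch x K) matrices instead of a
    (batch x K x M) tensor.
    """
    n, k = data1.shape[0], data2.shape[0]
    out = torch.empty(n, k, dtype=data1.dtype, device=data1.device)
    b_sq = (data2 ** 2).sum(dim=1)  # shape (K,)
    for start in range(0, n, batch_size):
        a = data1[start:start + batch_size]       # shape (b, M)
        a_sq = (a ** 2).sum(dim=1, keepdim=True)  # shape (b, 1)
        d = a_sq - 2.0 * (a @ data2.T) + b_sq     # broadcasts to shape (b, K)
        # clamp tiny negative values caused by rounding
        out[start:start + batch_size] = d.clamp(min=0.0)
    return out


torch.manual_seed(0)
X = torch.randn(1000, 16)
C = torch.randn(10, 16)
# Reference: the one-shot broadcast, affordable at this small size.
ref = ((X.unsqueeze(1) - C.unsqueeze(0)) ** 2).sum(dim=-1)
approx = pairwise_distance_mm(X, C, batch_size=128)
print(torch.allclose(ref, approx, atol=1e-4))
```

PyTorch also provides `torch.cdist`, which computes (non-squared) p-norm distances between two collections of row vectors; squaring its output with `p=2` gives the same quantity.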

AliscaChen · Sep 24 '20 03:09