kmeans_pytorch
Trying to solve OOM errors on large-scale datasets
Hi, this is an amazing module. However, when I set a large number of clusters or the dataset is very large, I run into OOM (out-of-memory) errors. I have refactored the code to compute the distances in batches. Feel free to use it if it works for you.
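A quick back-of-the-envelope calculation shows why both a large cluster count and a large dataset trigger OOM. Assuming the distance step broadcasts an N*1*M tensor against a 1*K*M one (as the appendix code does per batch), the difference tensor holds N × K × M elements. The helper `intermediate_bytes` and the workload sizes below are hypothetical, chosen only for illustration, and float32 (4 bytes per element) is assumed.

```python
def intermediate_bytes(n_points, n_clusters, n_features, batch_size=None, bytes_per_elem=4):
    """Size in bytes of the (rows x K x M) difference tensor built by broadcasting.

    This is a lower bound on peak memory: the subtraction and the squaring
    each materialise a tensor of this size. bytes_per_elem=4 assumes float32.
    """
    rows = n_points if batch_size is None else min(batch_size, n_points)
    return rows * n_clusters * n_features * bytes_per_elem


# Hypothetical workload: 1,000,000 points, 1,000 clusters, 128 features.
full = intermediate_bytes(1_000_000, 1_000, 128)
batched = intermediate_bytes(1_000_000, 1_000, 128, batch_size=10_000)
print(f"full: {full / 1e9:.1f} GB, batched: {batched / 1e9:.2f} GB")
# prints "full: 512.0 GB, batched: 5.12 GB"
```

Under these assumptions, even the default `batch_size=100000` would still need about 51.2 GB per batch, so the batch size has to shrink as K or M grows.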
Best regards,
Alisca
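Appendix: the refactored code for the (squared) Euclidean distance calculation. It processes data1 in chunks of batch_size rows, so each step only materialises a batch_size × K × M intermediate tensor instead of an N × K × M one (N and K are the row counts of data1 and data2, M is the feature dimension).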
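import math
import torch

def pairwise_distance(data1, data2, device=torch.device('cpu'), batch_size=100000):
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N*1*M
    A = data1.unsqueeze(dim=1)
    # 1*K*M (K = number of rows in data2, e.g. cluster centers)
    B = data2.unsqueeze(dim=0)

    # allocate the N*K result on the same device and dtype as the inputs
    dis_reduce = torch.zeros([data1.shape[0], data2.shape[0]], dtype=data1.dtype, device=device)
    for batch_idx in range(math.ceil(data1.shape[0] / batch_size)):
        start, end = batch_idx * batch_size, (batch_idx + 1) * batch_size
        # only a batch_size*K*M intermediate is materialised per step
        dis = (A[start:end] - B) ** 2.0
        # sum over features -> batch_size*K; no squeeze(), which breaks when K == 1
        dis_reduce[start:end] = dis.sum(dim=-1)
    return dis_reduce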
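A further option, not part of the proposal above, is to avoid the batch_size × K × M tensor entirely by expanding ||a − b||² = ||a||² − 2a·b + ||b||², which needs only a batch_size × K matrix product per step. The sketch below is a swapped-in technique; the function name `pairwise_sq_distance_mm` is hypothetical, and it assumes both inputs are floating-point tensors already on the same device. It checks itself against `torch.cdist`.

```python
import torch


def pairwise_sq_distance_mm(data1, data2, batch_size=100000):
    """Squared Euclidean distances between rows of data1 (N*M) and data2 (K*M).

    Hypothetical helper: uses ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 so each
    batch needs a batch_size*K matrix product instead of a batch_size*K*M tensor.
    """
    sq1 = (data1 ** 2).sum(dim=1, keepdim=True)  # N*1
    sq2 = (data2 ** 2).sum(dim=1).unsqueeze(0)   # 1*K
    out = torch.empty(data1.shape[0], data2.shape[0],
                      dtype=data1.dtype, device=data1.device)
    for start in range(0, data1.shape[0], batch_size):
        end = start + batch_size
        cross = data1[start:end] @ data2.T       # batch*K
        d = sq1[start:end] - 2.0 * cross + sq2   # broadcasts to batch*K
        # floating-point cancellation can push tiny distances slightly below zero
        out[start:end] = d.clamp(min=0.0)
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1000, 16, dtype=torch.float64)
    c = torch.randn(7, 16, dtype=torch.float64)
    d = pairwise_sq_distance_mm(x, c, batch_size=128)
    print(d.shape, torch.allclose(d, torch.cdist(x, c) ** 2))
```

The trade-off is numerical: the expanded form can lose precision when points are very close together (hence the clamp), whereas the direct difference used in the appendix code is exact up to rounding in each term.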