kmeans_pytorch
Support for "kl_divergence"
I have implemented support for the Kullback-Leibler (KL) divergence as a pairwise distance function, as follows. Shall I open a pull request for it?
def pairwise_kl_divergence(data1, data2, device=torch.device('cpu')):
    """Return the pairwise Kullback-Leibler divergence between two sets of rows.

    Every row is interpreted as a vector of unnormalized logits and converted
    into a probability distribution with a softmax over the last dimension.

    Args:
        data1: tensor of shape (N, M); row i becomes the distribution P_i.
        data2: tensor of shape (K, M); row j becomes the distribution Q_j.
            K may differ from N.
        device: device on which the computation is performed; both inputs
            are moved there first.

    Returns:
        Tensor of shape (N, K) whose entry [i, j] is KL(P_i || Q_j), i.e. the
        sum over the M categories of p * (log p - log q). The full sum is
        returned (not a per-category mean) so the value is the actual
        KL divergence.
    """
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # log_softmax turns logits into log-probabilities in a numerically
    # stable way (no explicit softmax followed by log).
    # N*1*M
    log_p = torch.nn.functional.log_softmax(data1, dim=-1).unsqueeze(dim=1)
    # 1*K*M
    log_q = torch.nn.functional.log_softmax(data2, dim=-1).unsqueeze(dim=0)

    # Written out explicitly rather than via F.kl_div: F.kl_div(input, target)
    # computes KL(target || input), an argument order that is easy to invert.
    # Broadcasting yields an N*K*M tensor of per-category terms.
    kl_terms = log_p.exp() * (log_p - log_q)

    # KL divergence is the sum over categories; return the N*K matrix.
    return kl_terms.sum(dim=-1)