
Support for "kl_divergence"

Open ControllableGeneration opened this issue 1 year ago • 0 comments

I have implemented support for the Kullback-Leibler divergence as a pairwise distance, as follows. Shall I open a pull request for it?

import torch

def pairwise_kl_divergence(data1, data2, device=torch.device('cpu')):
    # transfer to device
    data1, data2 = data1.to(device), data2.to(device)

    # N1 x 1 x M
    A = data1.unsqueeze(dim=1)

    # 1 x N2 x M
    B = data2.unsqueeze(dim=0)

    # normalize each point into a log-probability distribution
    A_normalized = torch.nn.functional.log_softmax(A, dim=-1)
    B_normalized = torch.nn.functional.log_softmax(B, dim=-1)

    # elementwise KL terms; note that F.kl_div(input, target) computes
    # KL(target || input), so entry (i, j) is KL(data2[j] || data1[i]) --
    # KL is asymmetric, so the direction matters
    kl_div = torch.nn.functional.kl_div(A_normalized, B_normalized, reduction='none', log_target=True)

    # sum over the support to get the KL divergence proper
    # (a mean would rescale it by 1/M); return N1 x N2 pairwise distance matrix
    kl_div = kl_div.sum(-1)
    return kl_div
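For reference, a quick self-contained sanity check of the distance matrix (the function is restated so the snippet runs on its own, using a sum over the last dimension, which matches the standard KL definition): the diagonal of the self-distance matrix should be exactly zero, since KL(p ‖ p) = 0, and every entry should be non-negative.

```python
import torch

def pairwise_kl_divergence(data1, data2, device=torch.device('cpu')):
    data1, data2 = data1.to(device), data2.to(device)
    A = data1.unsqueeze(dim=1)   # N1 x 1 x M
    B = data2.unsqueeze(dim=0)   # 1 x N2 x M
    A_log = torch.nn.functional.log_softmax(A, dim=-1)
    B_log = torch.nn.functional.log_softmax(B, dim=-1)
    kl = torch.nn.functional.kl_div(A_log, B_log, reduction='none', log_target=True)
    return kl.sum(-1)            # N1 x N2

torch.manual_seed(0)
data = torch.randn(4, 5)
D = pairwise_kl_divergence(data, data)
print(D.shape)                                                   # torch.Size([4, 4])
print(torch.allclose(D.diagonal(), torch.zeros(4), atol=1e-6))   # True: KL(p||p) = 0
print((D >= -1e-6).all().item())                                 # True: KL is non-negative
```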

ControllableGeneration · Mar 21 '24 02:03