MMD_AAE_PyTorch

An optimized version of compute_pairwise_distances in utils.py

Open · Crestina2001 opened this issue 1 year ago · 1 comment

The original function computes pairwise distances by looping over every pair of instances, which becomes slow as the number of instances grows. PyTorch broadcasting and a single matrix multiplication can replace the explicit loops:

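import torch

def compute_pairwise_distances(x, y):
    """
    Computes the squared Euclidean distance between every row of x (shape n x d)
    and every row of y (shape m x d). Returns an n x m matrix.
    """
    xx = (x ** 2).sum(dim=1, keepdim=True)  # (n, 1): squared norm of each row of x
    yy = (y ** 2).sum(dim=1, keepdim=True)  # (m, 1): squared norm of each row of y
    xy = torch.mm(x, y.t())                  # (n, m): dot products x_i . y_j
    dists = xx + yy.t() - 2. * xy            # broadcasts (n, 1) + (1, m) -> (n, m)
    # Floating-point cancellation can leave tiny negative entries; clamp them to 0.
    return dists.clamp(min=0.)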
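The vectorized form rests on expanding the squared distance between the i-th row of x and the j-th row of y into three terms that can each be computed for all pairs at once:

$$
\lVert x_i - y_j \rVert^2 = \lVert x_i \rVert^2 + \lVert y_j \rVert^2 - 2\, x_i^\top y_j
$$

The first term is the i-th entry of xx, the second is the j-th entry of yy, and the cross term is entry (i, j) of xy, scaled by 2.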

In this function:

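For x of shape (n, d) and y of shape (m, d): xx is an n × 1 column vector holding the squared Euclidean norm of each row of x, and yy is an m × 1 column vector holding the squared Euclidean norm of each row of y. xy is the n × m matrix whose (i, j) entry is the dot product of the i-th row of x and the j-th row of y. Adding the column vector xx to the row vector yy.t() broadcasts to an n × m matrix, so dists is the matrix whose (i, j) entry is the squared Euclidean distance between the i-th vector in x and the j-th vector in y.

Because all of the work is done by one matrix multiplication and a few elementwise tensor operations, this version avoids Python-level loops and should be much faster than the looped version when x and y contain many vectors, particularly on a GPU. One caveat: the expanded form is subject to floating-point cancellation, so entries that should be exactly zero can come out as small negative numbers.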
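As a quick sanity check, here is a minimal sketch that compares the vectorized function against an explicit double loop and against PyTorch's built-in `torch.cdist` (squared). The loop `pairwise_distances_loop` is a hypothetical reference written only for this comparison, not the original `utils.py` code. Random inputs use float64 so the comparison is not dominated by float32 rounding, and a small hand-computed case shows the actual values.

```python
import torch

def compute_pairwise_distances(x, y):
    # Vectorized form from the issue: ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    xx = (x ** 2).sum(dim=1, keepdim=True)   # (n, 1)
    yy = (y ** 2).sum(dim=1, keepdim=True)   # (m, 1)
    xy = torch.mm(x, y.t())                  # (n, m)
    return xx + yy.t() - 2. * xy             # broadcasts (n, 1) + (1, m) -> (n, m)

def pairwise_distances_loop(x, y):
    # Hypothetical reference: explicit double loop over all pairs (not the repo's code)
    out = torch.empty(x.shape[0], y.shape[0], dtype=x.dtype)
    for i in range(x.shape[0]):
        for j in range(y.shape[0]):
            out[i, j] = ((x[i] - y[j]) ** 2).sum()
    return out

torch.manual_seed(0)
x = torch.randn(5, 3, dtype=torch.float64)
y = torch.randn(4, 3, dtype=torch.float64)

fast = compute_pairwise_distances(x, y)
slow = pairwise_distances_loop(x, y)
ref = torch.cdist(x, y) ** 2   # built-in Euclidean distance, squared

print(fast.shape)                    # torch.Size([5, 4])
print(torch.allclose(fast, slow))    # True
print(torch.allclose(fast, ref))     # True

# Small hand-checkable case: x = {(0, 0), (1, 1)}, y = {(1, 0), (3, 4)}
x_small = torch.tensor([[0., 0.], [1., 1.]])
y_small = torch.tensor([[1., 0.], [3., 4.]])
print(compute_pairwise_distances(x_small, y_small).tolist())  # [[1.0, 25.0], [1.0, 13.0]]
```

If only the distance matrix is needed, `torch.cdist(x, y) ** 2` gives the same result with a built-in call; the hand-written version keeps the squared distances directly and avoids the square root and re-squaring.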

provided by GPT-4

Crestina2001, May 11 '23 07:05

Thanks for the advice!

mousecpn, May 11 '23 07:05