
Recommendation: Speeding up non-FAISS distance matrix calculations

Open ariel415el opened this issue 4 years ago • 1 comments

Hi,

Apparently the simple way of computing the squared L2 distance in PyTorch, torch.sum((x[:, None] - y[None, :]) ** 2, -1), is very inefficient: the broadcasting materializes a full N x M x D intermediate tensor.

Replacing the call https://github.com/iyttor/GPNN/blob/2d7994b676a0668957cbc2f9abb281b53bfe6361/model/gpnn.py#L115

with torch.cdist(queries.view(len(queries), -1), keys.view(len(keys), -1)) ** 2

can make it run almost 10x faster and is also more memory-efficient for large images.
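
A minimal sketch (not from the original post) comparing the two formulations on random data; the shapes below are made up and only stand in for the flattened patch tensors in gpnn.py:

import torch

queries = torch.randn(2048, 3 * 7 * 7)  # hypothetical flattened query patches
keys = torch.randn(4096, 3 * 7 * 7)     # hypothetical flattened key patches

# Naive broadcasting: materializes a 2048 x 4096 x 147 intermediate tensor
dist_naive = torch.sum((queries[:, None] - keys[None, :]) ** 2, -1)

# cdist-based: same squared distances without the huge intermediate
dist_cdist = torch.cdist(queries, keys) ** 2

print(torch.allclose(dist_naive, dist_cdist, atol=1e-3))  # True up to float error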

If cdist runs out of memory, use this:

import torch

def efficient_compute_distances(x, y):
    # Squared L2 distances via the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>
    dist = (x * x).sum(1)[:, None] + (y * y).sum(1)[None, :] - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return dist

def compute_distances_batch(queries, keys, b):
    # Fill the full distance matrix in row batches of size b to limit peak memory
    queries = queries.reshape(queries.size(0), -1)
    keys = keys.reshape(keys.size(0), -1)
    dist_mat = torch.zeros((queries.shape[0], keys.shape[0]), dtype=torch.float16, device=queries.device)
    n_batches = len(queries) // b
    for i in range(n_batches):
        # No extra squaring here: efficient_compute_distances already returns squared distances
        dist_mat[i * b:(i + 1) * b] = efficient_compute_distances(queries[i * b:(i + 1) * b], keys)
    if len(queries) % b != 0:
        dist_mat[n_batches * b:] = efficient_compute_distances(queries[n_batches * b:], keys)

    return dist_mat
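
For reference, a hedged usage sketch of the batched helper above; the shapes and the batch size of 1024 are arbitrary example values, not taken from the repo:

queries = torch.randn(10000, 147)  # hypothetical flattened query patches
keys = torch.randn(12000, 147)     # hypothetical flattened key patches

# Pick the largest batch size that fits in memory
dist_mat = compute_distances_batch(queries, keys, b=1024)
NNs = torch.argmin(dist_mat, dim=1)  # index of the nearest key patch for each query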

Both ways can lead to some numerical instabilities.

see https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065
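
One common mitigation for that instability (my own suggestion, not something tested against this repo) is to clamp the expansion result at zero, since true squared distances are non-negative and the negative values come from floating-point cancellation:

def efficient_compute_distances_stable(x, y):
    # Same expansion as above, but clamp tiny negative values caused by cancellation
    dist = (x * x).sum(1)[:, None] + (y * y).sum(1)[None, :] - 2.0 * torch.mm(x, y.t())
    return dist.clamp_(min=0)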

ariel415el avatar Nov 04 '21 08:11 ariel415el

Is there a similar way to batch out the combine_patches call?

JDvorak avatar Nov 27 '21 00:11 JDvorak