GPNN
Recommendation: Speeding up non-faiss distance matrix calculations
Hi,
Apparently the simple way of computing L2 distances in PyTorch:
torch.sum((x[:, None] - y[None, :]) ** 2, -1)
is very inefficient.
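To illustrate the claim, here is a minimal self-contained sketch (with made-up sizes: 8 query patches and 12 key patches, each flattened to 27 values) showing that the broadcasting version and torch.cdist produce the same matrix, while the former materializes a large intermediate tensor:

```python
import torch

queries = torch.randn(8, 27)
keys = torch.randn(12, 27)

# Naive broadcasting: materializes an (8, 12, 27) intermediate tensor
# before the reduction, which is what makes it slow and memory-hungry.
naive = torch.sum((queries[:, None] - keys[None, :]) ** 2, -1)

# cdist returns the L2 norm directly; squaring gives the same matrix
# without the large intermediate.
fast = torch.cdist(queries, keys) ** 2

assert torch.allclose(naive, fast, atol=1e-3)
```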
Replacing the call at https://github.com/iyttor/GPNN/blob/2d7994b676a0668957cbc2f9abb281b53bfe6361/model/gpnn.py#L115
with torch.cdist(queries.view(len(queries), -1), keys.view(len(keys), -1)) ** 2
can run almost 10x faster and is also more memory efficient for large images.
If cdist runs out of memory, use this instead:

def efficient_compute_distances(x, y):
    # Squared L2 distances via the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
    dist = (x * x).sum(1)[:, None] + (y * y).sum(1)[None, :] - 2.0 * torch.mm(x, torch.transpose(y, 0, 1))
    return dist

def compute_distances_batch(queries, keys, b):
    # Fill the distance matrix b query rows at a time to cap peak memory.
    queries = queries.reshape(queries.size(0), -1)
    keys = keys.reshape(keys.size(0), -1)
    dist_mat = torch.zeros((queries.shape[0], keys.shape[0]), dtype=torch.float16, device=queries.device)
    n_batches = len(queries) // b
    for i in range(n_batches):
        # efficient_compute_distances already returns squared distances,
        # so no extra ** 2 is needed here (it would give fourth powers).
        dist_mat[i * b:(i + 1) * b] = efficient_compute_distances(queries[i * b:(i + 1) * b], keys)
    if len(queries) % b != 0:
        dist_mat[n_batches * b:] = efficient_compute_distances(queries[n_batches * b:], keys)
    return dist_mat
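As a sanity check, the batched result should agree with the cdist version. A minimal self-contained sketch of the same batching idea (re-declaring the helpers so it runs standalone, using float32 and deriving the device from the inputs; the patch shapes and batch size are made up):

```python
import torch

def efficient_compute_distances(x, y):
    # Squared L2 distances via ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x.y
    return (x * x).sum(1)[:, None] + (y * y).sum(1)[None, :] - 2.0 * torch.mm(x, y.t())

def compute_distances_batch(queries, keys, b):
    queries = queries.reshape(queries.size(0), -1)
    keys = keys.reshape(keys.size(0), -1)
    dist_mat = torch.zeros((queries.shape[0], keys.shape[0]), device=queries.device)
    n_batches = len(queries) // b
    for i in range(n_batches):
        dist_mat[i * b:(i + 1) * b] = efficient_compute_distances(queries[i * b:(i + 1) * b], keys)
    if len(queries) % b != 0:
        dist_mat[n_batches * b:] = efficient_compute_distances(queries[n_batches * b:], keys)
    return dist_mat

# 10 query patches and 7 key patches of shape 3x3x3, batch size 4
# (exercising both the full batches and the remainder branch).
queries = torch.randn(10, 3, 3, 3)
keys = torch.randn(7, 3, 3, 3)
batched = compute_distances_batch(queries, keys, b=4)
reference = torch.cdist(queries.view(10, -1), keys.view(7, -1)) ** 2
assert torch.allclose(batched, reference, atol=1e-2)
```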
Both approaches can lead to some numerical instabilities;
see https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065
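One common mitigation (an assumption on my part, not something the repo currently does) is to clamp the expansion-based result at zero, since floating-point cancellation can make near-zero distances slightly negative:

```python
import torch

def efficient_compute_distances_stable(x, y):
    # The expansion ||x||^2 + ||y||^2 - 2 * x.y can dip slightly below zero
    # through floating-point cancellation; clamp to keep distances valid.
    dist = (x * x).sum(1)[:, None] + (y * y).sum(1)[None, :] - 2.0 * torch.mm(x, y.t())
    return dist.clamp_(min=0.0)

x = torch.randn(5, 16)
dist = efficient_compute_distances_stable(x, x)
assert (dist >= 0).all()  # self-distances cannot go negative after clamping
```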
Is there a similar way to batch out the combine_patches call?