pytorch_sparse
how to implement cdist
Hi there, thanks a million for this library.
I am trying to figure out how to compute distances between every pair of rows (or columns) of two matrices, e.g. like torch.cdist.
Here is one way:
```python
def sparse_cdist(a: SparseTensor, b: SparseTensor):
    # Repeat each row of a b.size(0) times and tile b a.size(0) times, so
    # that entry k of the flattened result pairs a[k // b.size(0)] with
    # b[k % b.size(0)].
    a_repeated = cat(
        [cat([a[i, :]] * b.size(0), dim=0) for i in range(a.size(0))], dim=0
    )
    b_repeated = cat([b] * a.size(0), dim=0)
    distances = sparse_distance(a_repeated, b_repeated)
    # Setting .requires_grad on a non-leaf tensor raises; detach instead.
    distances = distances.detach()
    return distances.view((a.size(0), b.size(0)))
```
and another:
```python
def sparse_cdist2(a: SparseTensor, b: SparseTensor):
    with torch.no_grad():
        distances = torch.zeros(
            (a.size(0), b.size(0)),
            device=a.device(),
            dtype=a.dtype(),
            requires_grad=False,
        )
        # One call per pair of rows: correct, but no GPU parallelism.
        for i in range(a.size(0)):
            for j in range(b.size(0)):
                distances[i, j] = sparse_distance(a[i, :], b[j, :])
    return distances
```
and another:
```python
def sparse_cdist3(a: SparseTensor, b: SparseTensor):
    with torch.no_grad():
        distances = torch.zeros(
            (a.size(0), b.size(0)),
            device=a.device(),
            dtype=a.dtype(),
            requires_grad=False,
        )
        # Loop over the smaller dimension only; each iteration fills one
        # full row/column of the result in a single batched call.
        if a.size(0) <= b.size(0):
            for i in range(a.size(0)):
                idx = torch.tensor([i], device=a.device()).expand(b.size(0))
                a_repeated = a.index_select(0, idx)
                distances[i, :] = sparse_distance(a_repeated, b)
        else:
            for j in range(b.size(0)):
                idx = torch.tensor([j], device=b.device()).expand(a.size(0))
                b_repeated = b.index_select(0, idx)
                distances[:, j] = sparse_distance(a, b_repeated)
    return distances
```
where sparse_distance is defined as follows:
```python
def sparse_distance(a: SparseTensor, b: SparseTensor):
    # Row-wise Euclidean distance: sqrt(sum((a - b)^2, dim=1)).
    c = a + b.mul_nnz(torch.tensor(-1, dtype=a.dtype(), device=a.device()), "coo")
    c = c * c
    c = reduction(c, dim=1, reduce="sum")
    # Small epsilon keeps the sqrt differentiable at zero distance.
    return torch.sqrt(c + 1e-9)
```
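For reference, the same row-wise distance on dense tensors (plain torch, same epsilon under the sqrt) would be:

```python
import torch


def dense_rowwise_distance(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-9):
    # Dense mirror of sparse_distance: Euclidean distance between matching
    # rows of a and b, returned as a 1-D tensor of length a.size(0).
    c = a - b
    return torch.sqrt((c * c).sum(dim=1) + eps)
```

Comparing sparse_distance against this on the densified inputs is a quick sanity check of the sparse implementation.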
The first one has the disadvantage that it creates huge intermediate matrices and eats a lot of memory, while the second one doesn't benefit from GPU parallelism at all. The third seems okay-ish, but maybe there is an even better solution? Ideally one would only load matrices a and b onto the GPU and reserve additional memory only for the result. Does anybody have an idea how to do that with what is currently possible in torch_sparse? Or would it be necessary to write a dedicated CUDA kernel for that? Any suggestions are very welcome.
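One possibility that avoids materializing repeated matrices altogether is the usual cdist trick of expanding the squared distance, ||a_i − b_j||² = ||a_i||² + ||b_j||² − 2·⟨a_i, b_j⟩: the only cross term is a single matrix product, which could be done sparsely (e.g. via torch_sparse's spspmm), leaving only the (n, m) result dense. A dense sketch of the idea in plain torch, not a tested SparseTensor implementation:

```python
import torch


def cdist_via_gram(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-9):
    # ||a_i - b_j||^2 = ||a_i||^2 + ||b_j||^2 - 2 * <a_i, b_j>.
    # Only a, b and the (n, m) result are materialized; the cross term is
    # one matmul instead of n * m repeated rows.
    a_sq = (a * a).sum(dim=1, keepdim=True)      # (n, 1) row norms of a
    b_sq = (b * b).sum(dim=1, keepdim=True).t()  # (1, m) row norms of b
    sq = a_sq + b_sq - 2.0 * (a @ b.t())
    # Cancellation can make sq slightly negative; clamp before the sqrt.
    return torch.sqrt(sq.clamp_min(0.0) + eps)
```

For sparse inputs, the row norms would come from summing c * c per row (as in sparse_distance above) and the cross term from a sparse-sparse matmul. One caveat: this formulation can lose precision for nearly identical rows, which is exactly why the clamp is there.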