Performance drop for the batching rule for aten::_sparse_mm
Hello @zou3519, @samdow.
TLDR: I got the following warning: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation.
I was trying to use vmap for batching sparse-dense matrix multiplications:
import torch
from functorch import vmap

# Batch of two sparse 4x4 matrices in COO format
A = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
                          [0, 1, 2, 3, 0, 1, 2, 3],
                          [0, 1, 2, 3, 0, 1, 2, 3]]),
    values=torch.tensor([1., 1., 1., 1., 2., 2., 2., 2.]),
    size=(2, 4, 4))

X = torch.tensor([[-0.0533, -1.3950, -0.2621],
                  [-1.0800,  0.3210,  0.7954],
                  [ 0.7737,  0.3655,  0.5691],
                  [-0.3505, -1.0423, -2.0650]])

# Map the sparse-dense matmul over the leading (batch) dim of A
bspmm = vmap(torch.sparse.mm, in_dims=(0, None))
Z = bspmm(A, X)
In [1]: A.shape
Out[1]: torch.Size([2, 4, 4])
In [2]: X.shape
Out[2]: torch.Size([4, 3])
In [3]: Z.shape
Out[3]: torch.Size([2, 4, 3])
which yields correct results, but:
.../functorch/_src/vmap.py:489: UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::_sparse_mm. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/functorch/BatchedFallback.cpp:84.)
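For what it's worth, the result does agree with the equivalent dense computation (a quick sanity-check sketch; the broadcasted torch.matmul over the batch dim is my own addition, not part of the original run):

# Sanity check: the vmapped result matches the dense computation.
# matmul broadcasts (2, 4, 4) @ (4, 3) -> (2, 4, 3).
Z_dense = torch.matmul(A.to_dense(), X)
assert torch.allclose(Z, Z_dense)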
Are there plans to implement batching for this operation in the near future?
Thanks
We currently do not support vmap over sparse tensors. Could you tell us a bit more about your use case please? (cc @cpuhrsch)
TLDR: Filtering on topological spaces (e.g. graphs, simplicial/cell complexes) requires K sparse-dense matmul ops, i.e. sum_k S_k · X · W_k, where each S_k is sparse.
My specific use case is implementing a differentiable filtering operation on topological spaces (take graphs as an example; higher-order relational structures in the general case). From looking around, it seems the only way to do this is with a for loop like this:
out = torch.stack([Sk.mm(X).mm(W[k]) for k, Sk in enumerate(G.adj)], dim=0).sum(dim=0)
However, with vmap the for loop above becomes unnecessary:
import torch
from functorch import vmap

mm = vmap(torch.sparse.mm, in_dims=(0, None))
comp = mm(S, X)                       # (K, N, Fin)
out = torch.bmm(comp, W).sum(dim=0)   # (N, Fout)
Here S is a K×N×N sparse tensor, X is an N×Fin dense matrix, and W is a K×Fin×Fout dense tensor.
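In the meantime, a possible workaround (my own sketch, not a functorch API; batched_sparse_mm is a hypothetical helper) is to fold the batch dimension of S into the row dimension and issue a single 2-D sparse-dense matmul:

import torch

def batched_sparse_mm(S, X):
    # Fold the (K, N, N) sparse batch into one (K*N, N) sparse matrix,
    # multiply once, then unfold the result back to (K, N, Fin).
    K, N, _ = S.shape
    S = S.coalesce()
    k, i, j = S.indices()                     # COO indices, shape (3, nnz)
    flat = torch.sparse_coo_tensor(torch.stack([k * N + i, j]),
                                   S.values(), size=(K * N, N))
    return torch.sparse.mm(flat, X).view(K, N, X.shape[1])

# out = torch.bmm(batched_sparse_mm(S, X), W).sum(dim=0)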
Maybe this is too specific to my use case, but I think it could be very useful for folks interested in machine learning on graphs.
We would also be interested in a performant vmap for sparse-dense matrix-vector multiplications.
We use sparse matrices to represent different interpolations in MR data, for example in non-uniform FFTs or volume-to-slice projections. In both cases, it is much more convenient (and faster) to construct the sparse matrix once and use it in matmuls, compared to other Python-only approaches.
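For the matrix-vector case there is at least a partial workaround without vmap (a minimal sketch under assumed shapes, not a general replacement): stack the batch of vectors as columns, so the whole batch reduces to a single sparse-dense matmul:

import torch

# Toy sparse interpolation matrix (3x3) and a batch of 5 dense vectors
A = torch.sparse_coo_tensor(torch.tensor([[0, 1, 2], [2, 0, 1]]),
                            torch.tensor([1., 2., 3.]), size=(3, 3))
V = torch.randn(3, 5)          # 5 length-3 vectors, stored as columns
Y = torch.sparse.mm(A, V)      # (3, 5): column j equals A @ V[:, j]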