pytorch_scatter
scatter_weighted_add
I have an [n, d] tensor A, a [t, n] tensor weight, and an [n] tensor index.
I would like to obtain a [t, m, d] tensor A_scatter (where m is the scatter output size along dim 0, i.e. the number of index groups) such that A_scatter[i] is the result of scatter(A * weight[i][:, None], index, dim=0, reduce='sum'). That is, it's the scatter sum of the rows of A according to index, but with the rows weighted by weight[i].
Currently, in order to do this, I need to construct the [t, n, d] tensor A[None, :, :] * weight[:, :, None] and then scatter that. But the construction of this tensor is very memory-inefficient and gives me out-of-memory errors in PyTorch.
Could there be a built-in scatter_weighted_add that performs a weighted sum, without needing to form the extra weighted tensor?
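For concreteness, a minimal sketch of the current approach (the placeholder shapes and the use of torch_scatter.scatter are assumptions; the names are illustrative, not taken from the issue):

```python
import torch
from torch_scatter import scatter

n, d, t, m = 10_000, 64, 8, 100
A = torch.randn(n, d)
weight = torch.randn(t, n)
index = torch.randint(0, m, (n,))

# Current approach: materialize the full [t, n, d] weighted tensor, then scatter.
# The temporary is t times the size of A, which is what runs out of memory.
weighted = A[None, :, :] * weight[:, :, None]                           # [t, n, d]
A_scatter = scatter(weighted, index, dim=1, dim_size=m, reduce='sum')   # [t, m, d]
```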
Are you referring to a matrix multiplication (since your weight seems to be 2-dimensional)? In that case, you may want to take a look at https://pyg-lib.readthedocs.io/en/latest/modules/ops.html.
I would really like a weighted scatter add.
That is, while scatter_add(src, index, dim) adds subtensors along dim in src according to index, I want weighted_scatter_add(src, index, weight, dim) to add subtensors weighted by weight along dim in src according to index.
The reason my weight was 2-dimensional is that I would like to do this efficiently with several different weight vectors at the same time, using the same index. But if that isn't possible, it would still be a big help to have weighted_scatter_add with a 1-d weight; I can just call it serially with the different weight vectors, as in the sketch below.
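A minimal sketch of that serial fallback, assuming a hypothetical weighted_scatter_add helper built on torch_scatter.scatter with a 1-d weight (names and shapes are illustrative):

```python
import torch
from torch_scatter import scatter

def weighted_scatter_add(src, index, weight, dim_size):
    # src: [n, d], index: [n], weight: [n]  ->  [dim_size, d]
    # Still forms one [n, d] temporary per call, but never the full [t, n, d] tensor.
    return scatter(src * weight[:, None], index, dim=0, dim_size=dim_size, reduce='sum')

n, d, t, m = 10_000, 64, 8, 100
A = torch.randn(n, d)
weight = torch.randn(t, n)
index = torch.randint(0, m, (n,))

# Serial over the t weight vectors, reusing the same index each time.
A_scatter = torch.stack([weighted_scatter_add(A, index, weight[i], m)
                         for i in range(t)])   # [t, m, d]
```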
Sorry for the late reply, but why does
scatter_add(weight * src, index, dim)
not work for you?
It's just memory-intensive. It actually does work for me now, so I guess there's no problem, but if src is a very large tensor, to the point that memory becomes the bottleneck in your workflow, then needing to form the additional tensor weight * src is a noticeable inefficiency. I would guess that matrix-vector products and dot products are, under the hood, not implemented as a pointwise multiplication that generates another tensor followed by a summation, in part for the same reason.
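One workaround that avoids the dense weight * src temporary entirely (a sketch under my own assumptions, not something proposed in this thread): since the weighted scatter sum is linear in src, it can be written as a sparse [m, n] matrix applied to A, which also connects back to the matrix-multiplication reading above.

```python
import torch

n, d, m = 10_000, 64, 100
A = torch.randn(n, d)
w = torch.randn(n)                    # one 1-d weight vector
index = torch.randint(0, m, (n,))

# Sparse [m, n] matrix S with S[index[j], j] = w[j]; then S @ A is exactly the
# weighted scatter sum, with no dense [n, d] weighted temporary.
S = torch.sparse_coo_tensor(torch.stack([index, torch.arange(n)]), w, (m, n)).coalesce()
A_scatter_i = torch.sparse.mm(S, A)   # [m, d]
```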
Ah, I get your point. Yes, one could fuse these together. That would be a great optimization.
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?
This should be resolved with PyTorch 2.0 and torch.compile. Closing this for now.
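A sketch of what that could look like, assuming plain torch.Tensor.scatter_add_ wrapped in torch.compile; whether the multiply actually gets fused into the scatter kernel (so the [t, n, d] temporary never materializes) depends on the compiler, so treat this as an illustration rather than a guarantee:

```python
import torch

@torch.compile
def weighted_scatter_add(src, index, weight, dim_size):
    # src: [n, d], index: [n], weight: [t, n]  ->  [t, dim_size, d]
    t, n = weight.shape
    d = src.size(1)
    out = src.new_zeros(t, dim_size, d)
    idx = index.view(1, n, 1).expand(t, n, d)
    # Eagerly, weight[..., None] * src would materialize a [t, n, d] tensor;
    # the compiler may fuse this multiply into the scatter-add kernel.
    return out.scatter_add_(1, idx, weight.unsqueeze(-1) * src)

A_scatter = weighted_scatter_add(torch.randn(10_000, 64),
                                 torch.randint(0, 100, (10_000,)),
                                 torch.randn(8, 10_000), 100)   # [8, 100, 64]
```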