
scatter_weighted_add

Open forestyang4321 opened this issue 3 years ago • 5 comments

I have an [n, d] tensor A, a [t, n] tensor weight, and an [n] tensor index.

I would like to obtain a [t, m, d] tensor A_scatter (where m is the number of index buckets) such that A_scatter[i] is the result of scatter(A * weight[i][:, None], index, dim=0, reduce='sum'). That is, it's the scatter sum of the rows of A according to index, but with the rows weighted by weight[i].

Currently, in order to do this, I need to construct the [t, n, d] tensor A[None, :, :] * weight[:, :, None] and then scatter that. But constructing this tensor is very memory inefficient and gives me out-of-memory errors in PyTorch.
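A minimal sketch of this workaround, assuming torch_scatter is installed (the sizes t, n, d, m here are illustrative):

```python
import torch
from torch_scatter import scatter

t, n, d, m = 4, 1000, 64, 10          # m = number of output buckets
A = torch.randn(n, d)
weight = torch.randn(t, n)
index = torch.randint(0, m, (n,))

# Materializes the full [t, n, d] weighted tensor before reducing --
# this intermediate is what causes the out-of-memory errors.
weighted = A[None, :, :] * weight[:, :, None]                          # [t, n, d]
A_scatter = scatter(weighted, index, dim=1, dim_size=m, reduce='sum')  # [t, m, d]
```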

Can there be a built-in scatter_weighted_add that performs a weighted sum, without needing to form the extra weighted tensor?

forestyang4321 avatar Jul 30 '22 19:07 forestyang4321

Are you referring to a matrix multiplication (since your weight seems to be 2-dimensional)? In that case, you may want to take a look at https://pyg-lib.readthedocs.io/en/latest/modules/ops.html.

rusty1s avatar Aug 01 '22 13:08 rusty1s

I would really like a weighted scatter add.

That is, while scatter_add(src, index, dim) adds subtensors along dim in src according to index, I want weighted_scatter_add(src, index, weight, dim) to add those subtensors along dim according to index, with each subtensor scaled by its corresponding entry of weight.

My weight was 2-dimensional because I would like to do this efficiently with several different weight vectors at the same time, using the same index. But if that isn't possible, it would still be a big help to have weighted_scatter_add with a 1-D weight; I can just call it serially with the different weight vectors.
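For clarity, here is a reference implementation of the proposed op for the 1-D case, written in plain PyTorch. The name and signature are hypothetical, not part of torch_scatter, and this version still materializes weight * src, which is exactly the allocation a fused kernel would avoid:

```python
import torch

def weighted_scatter_add(src, index, weight, dim=0, dim_size=None):
    """out[j] = sum of weight[k] * src[k] over all k with index[k] == j (dim=0 case)."""
    if dim_size is None:
        dim_size = int(index.max()) + 1
    # Reshape the 1-D weight so it broadcasts against src along `dim`.
    shape = [1] * src.dim()
    shape[dim] = -1
    weighted = src * weight.view(shape)   # unfused: still allocates weight * src
    out_shape = list(src.shape)
    out_shape[dim] = dim_size
    return src.new_zeros(out_shape).index_add_(dim, index, weighted)
```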

forestyang4321 avatar Aug 01 '22 16:08 forestyang4321

Sorry for the late reply, but why does

scatter_add(weight * src, index, dim)

not work for you?
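A quick sketch checking that this unfused formulation matches the per-row definition for a 1-D weight (the sizes are illustrative):

```python
import torch
from torch_scatter import scatter_add

n, d, m = 6, 3, 2
src = torch.randn(n, d)
weight = torch.randn(n)
index = torch.randint(0, m, (n,))

out = scatter_add(weight[:, None] * src, index, dim=0, dim_size=m)

# The same reduction written out row by row.
manual = torch.zeros(m, d)
for k in range(n):
    manual[index[k]] += weight[k] * src[k]
assert torch.allclose(out, manual, atol=1e-6)
```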

rusty1s avatar Aug 10 '22 11:08 rusty1s

It's just memory intensive. It actually does work for me now, so I guess there's no problem, but if src is a very large tensor, to the point that memory becomes the bottleneck in your workflow, then needing to form the additional tensor weight * src is a noticeable inefficiency. I would guess this is partly why matrix-vector multiplications and dot products are not implemented under the hood as a pointwise multiplication that materializes another tensor, followed by a summation.
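One possible stopgap under that constraint (a sketch, not a fused kernel): chunk over the t weight vectors so the intermediate stays [chunk, n, d] instead of [t, n, d]; the chunk parameter is a tunable, illustrative choice:

```python
import torch
from torch_scatter import scatter

def chunked_weighted_scatter(A, weight, index, m, chunk=8):
    # Trades a factor of t/chunk in peak memory for a Python-level loop.
    out = []
    for w in weight.split(chunk, dim=0):        # [chunk, n]
        weighted = A[None] * w[:, :, None]      # [chunk, n, d]
        out.append(scatter(weighted, index, dim=1, dim_size=m, reduce='sum'))
    return torch.cat(out, dim=0)                # [t, m, d]
```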

forestyang4321 avatar Aug 10 '22 16:08 forestyang4321

Ah, I get your point. Yes, one could fuse these two operations together. That would be a great optimization.

rusty1s avatar Aug 10 '22 16:08 rusty1s

This issue has had no activity for 6 months. It will be closed in 2 weeks unless there is new activity. Is this issue already resolved?

github-actions[bot] avatar Feb 07 '23 01:02 github-actions[bot]

This should be resolved with PyTorch 2.0 and torch.compile. Closing this for now.
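A sketch of what that route looks like: the same semantics as the earlier reference, wrapped in torch.compile so the backend can fuse the pointwise multiply into the reduction instead of writing out weight * src (whether fusion actually happens depends on the compiler backend):

```python
import torch

@torch.compile
def weighted_scatter_add(src, index, weight, dim_size):
    # The multiply feeding index_add_ is a candidate for fusion under
    # TorchInductor, so the weighted intermediate need not be materialized.
    out = src.new_zeros(dim_size, src.size(1))
    return out.index_add_(0, index, weight[:, None] * src)
```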

rusty1s avatar Feb 07 '23 06:02 rusty1s