
[torch_scatter 2.1.2] Scatter max bug with negative numbers

Open lemyx opened this issue 1 year ago • 2 comments

pip install torch_scatter

The version of torch_scatter is 2.1.2

import torch
from torch_scatter import scatter_max

# The "* -1" makes every entry of src <= 0.
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) * -1
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
# Zero-filled destination buffer, passed as out=.
out = src.new_zeros((2, 6))

out, argmax = scatter_max(src, index, out=out)

print(out)
print(argmax)

The printed out tensor is all zeros, even though most of the expected per-index maxima are negative:

tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])

lemyx, Jan 06 '25 03:01

Can you paste the installation log? It looks like the installation may have failed.

rusty1s, Jan 06 '25 03:01
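Alongside the installation log, a small script like the one below can report which torch and torch_scatter versions are actually installed. torch_scatter wheels are built against specific PyTorch and CUDA versions, so a mismatch there is worth ruling out. This is a sketch: installed_versions and torch_build are hypothetical helper names, while importlib.metadata.version and torch.version.cuda are standard APIs. It uses only the standard library plus an optional torch import.

```python
from importlib import metadata
from typing import Dict, Optional


def installed_versions(*dists: str) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    report: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


def torch_build() -> Optional[str]:
    """Describe the PyTorch build (version and CUDA), or None if torch won't import."""
    try:
        import torch
    except (ImportError, OSError):
        return None
    # torch.version.cuda is None for CPU-only builds.
    return f"torch {torch.__version__}, CUDA {torch.version.cuda}"


if __name__ == "__main__":
    for name, version in installed_versions("torch", "torch_scatter").items():
        print(f"{name}: {version or 'not installed'}")
    print(torch_build() or "torch is not importable")
```

Pasting this output together with the pip log makes it easier to tell a broken install apart from a behavioural problem in scatter_max itself.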

From what I see, the "out" tensor comes back outdated:

out, argmax = scatter_max(src, index, out=out)  -> bug
bout, bargmax = scatter_max(src, index)         -> good


import torch
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) * -1
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out = src.new_zeros((2, 6))

# Case 1: CPU, zero-filled out buffer (the reported call)
out, argmax = scatter_max(src, index, out=out)
print("1: CPU, out=zeros")
print(out)
print(argmax)

# Case 2: CUDA, no out buffer (requires a GPU)
ci = index.cuda()
a = src.to(torch.float32).cuda()
aout, aargmax = scatter_max(a, ci)
print("2: CUDA, no out")
print(aout)
print(aargmax)

# Case 3: CPU, no out buffer
bout, bargmax = scatter_max(src, index)
print("3: CPU, no out")
print(bout)
print(bargmax)




lancercat, Feb 10 '25 17:02
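lancercat's comparison points at the out= argument as the trigger. One reading that would explain the all-zero result, stated here as an assumption and not as confirmed torch_scatter behaviour, is that scatter_max treats the values already stored in out as part of the reduction. Under that reading, a zero-filled buffer beats every non-positive input. The pure-Python model below implements that assumed semantics for the first row of the issue's example, with the "* -1" already applied. scatter_max_row is a hypothetical helper, not torch_scatter code. The example contrasts three cases: a zero buffer, no buffer, and a buffer pre-filled with -inf.

```python
import math
from typing import List, Optional, Sequence, Tuple


def scatter_max_row(
    src: Sequence[float],
    index: Sequence[int],
    dim_size: int,
    out: Optional[Sequence[float]] = None,
) -> Tuple[List[float], List[int]]:
    """Hypothetical 1-D model of scatter-max; NOT torch_scatter's implementation.

    Assumed semantics: if `out` is given, its current values take part in the
    max, so a src value only wins a slot when it is strictly larger. If `out`
    is omitted, slots start at -inf and slots that received nothing become 0.
    """
    provided = out is not None
    acc = list(out) if provided else [-math.inf] * dim_size
    # Sentinel len(src) marks "no src element won this slot" (a model choice).
    arg = [len(src)] * dim_size
    for pos, (value, slot) in enumerate(zip(src, index)):
        if value > acc[slot]:
            acc[slot] = value
            arg[slot] = pos
    if not provided:
        # Empty slots are reported as 0 rather than -inf.
        acc = [0.0 if v == -math.inf else v for v in acc]
    return acc, arg


# First row of the issue's example, with the "* -1" already applied.
src_row = [-2.0, 0.0, -1.0, -4.0, -3.0]
index_row = [4, 5, 4, 2, 3]

# Zero-filled buffer, as in the report: no value is > 0, so nothing wins.
zeros_out, zeros_arg = scatter_max_row(src_row, index_row, 6, out=[0.0] * 6)
print(zeros_out)  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(zeros_arg)  # [5, 5, 5, 5, 5, 5]

# No buffer: the negative maxima survive.
fresh_out, fresh_arg = scatter_max_row(src_row, index_row, 6)
print(fresh_out)  # [0.0, 0.0, -4.0, -3.0, -1.0, 0.0]
print(fresh_arg)  # [5, 5, 3, 4, 2, 1]

# Buffer pre-filled with -inf: same maxima, but empty slots stay -inf.
ninf_out, ninf_arg = scatter_max_row(src_row, index_row, 6, out=[-math.inf] * 6)
print(ninf_out)  # [-inf, -inf, -4.0, -3.0, -1.0, 0.0]
```

Under this model, omitting out= or pre-filling the buffer with -inf both recover the negative maxima. The only difference is that slots no index points to stay at -inf in the second case instead of being reported as 0.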

This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?

github-actions[bot], Aug 10 '25 02:08

I ran into this issue as well. Has it been resolved?

yuanxuanS, Dec 03 '25 12:12