pytorch_scatter
[torch_scatter 2.1.2] Scatter max bug with negative numbers
```
pip install torch_scatter
```
The version of torch_scatter is 2.1.2.
```python
import torch
from torch_scatter import scatter_max

# Every entry of src is <= 0 after multiplying by -1.
src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) * -1
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
# Pre-allocated output buffer, filled with zeros.
out = src.new_zeros((2, 6))
out, argmax = scatter_max(src, index, out=out)
print(out)
print(argmax)
```
The result of `out` is:
```
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
```
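For reference, the values `scatter_max` would be expected to return for this input can be worked out by hand. The sketch below is a hypothetical plain-Python helper, `reference_scatter_max` (not part of torch_scatter), that reduces a 2-D input along its last dimension. It assumes, as torch_scatter appears to do when `out` is omitted, that slots receiving no value report 0 and an argmax equal to `src.size(-1)` (here 5); treat that convention as an assumption.

```python
# Hypothetical helper (not torch_scatter code): plain-Python scatter-max along
# the last dimension of a 2-D input, used only to show the expected values.
def reference_scatter_max(src, index, dim_size):
    out, argmax = [], []
    for src_row, idx_row in zip(src, index):
        best = [None] * dim_size            # running max per output slot
        arg = [len(src_row)] * dim_size     # assumed "empty" argmax = src.size(-1)
        for j, (value, slot) in enumerate(zip(src_row, idx_row)):
            if best[slot] is None or value > best[slot]:
                best[slot] = value
                arg[slot] = j
        # Assumed convention: slots that received nothing are reported as 0.
        out.append([0.0 if v is None else v for v in best])
        argmax.append(arg)
    return out, argmax


# src after "* -1"; 0 * -1 is written as 0.0 for readability.
src = [[-2.0, 0.0, -1.0, -4.0, -3.0], [0.0, -2.0, -1.0, -3.0, -4.0]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out, argmax = reference_scatter_max(src, index, dim_size=6)
print(out)     # [[0.0, 0.0, -4.0, -3.0, -1.0, 0.0], [0.0, -4.0, -1.0, 0.0, 0.0, 0.0]]
print(argmax)  # [[5, 5, 3, 4, 2, 1], [0, 4, 2, 5, 5, 5]]
```

Slots 2 to 5 of the first row, for example, should hold -4, -3, -1 and 0, not the all-zero row reported above.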
Can you paste the installation log? It looks like the installation may have failed.
From what I see, the tensor `out` is not being updated:
```python
out, argmax = scatter_max(src, index, out=out)  # -> bug
bout, bargmax = scatter_max(src, index)         # -> good
```
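One possible explanation, offered as an assumption rather than a confirmed diagnosis: if `scatter_max` treats the values already in `out` as the starting point of the reduction, a zero-filled `out` can never be beaten by inputs that are all <= 0. The sketch below models that accumulate-into-`out` behavior with a hypothetical plain-Python function, `scatter_max_into` (not a torch_scatter function), and shows that starting from `-inf` instead of zeros recovers the per-slot maxima.

```python
import math


# Hypothetical model (not torch_scatter code): reduce src into an existing
# buffer, keeping whatever is already in `out` unless a src value is larger.
def scatter_max_into(src, index, out):
    out = [row[:] for row in out]                    # work on a copy of the buffer
    argmax = [[len(s)] * len(o) for s, o in zip(src, out)]
    for r, (src_row, idx_row) in enumerate(zip(src, index)):
        for j, (value, slot) in enumerate(zip(src_row, idx_row)):
            if value > out[r][slot]:                 # strict '>' is an assumption
                out[r][slot] = value
                argmax[r][slot] = j
    return out, argmax


src = [[-2.0, 0.0, -1.0, -4.0, -3.0], [0.0, -2.0, -1.0, -3.0, -4.0]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]

# Zero-filled buffer, as with src.new_zeros((2, 6)): no value <= 0 can win.
zeros = [[0.0] * 6 for _ in range(2)]
print(scatter_max_into(src, index, zeros)[0])
# [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]

# Buffer filled with -inf: every scattered value can win its slot.
neg_inf = [[-math.inf] * 6 for _ in range(2)]
print(scatter_max_into(src, index, neg_inf)[0])
# [[-inf, -inf, -4.0, -3.0, -1.0, 0.0], [0.0, -4.0, -1.0, -inf, -inf, -inf]]
```

If that model matches torch_scatter, a workaround would be to pre-fill `out` with a value no larger than any input (for example `src.new_full((2, 6), float('-inf'))`), keeping in mind that empty slots would then stay at `-inf` rather than 0, or to omit `out` entirely, which matches the "good" case reported above.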
```python
import torch
from torch_scatter import scatter_max

src = torch.Tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) * -1
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

# 1) CPU, with a zero-initialized `out` buffer (the reported case)
out = src.new_zeros((2, 6))
out, argmax = scatter_max(src, index, out=out)
print("1: CPU, out=zeros")
print(out)
print(argmax)

# 2) CUDA, without `out` (needs a CUDA-capable GPU)
if torch.cuda.is_available():
    ci = index.cuda()
    a = src.to(torch.float32).cuda()
    aout, aargmax = scatter_max(a, ci)
    print("2: CUDA, no out")
    print(aout)
    print(aargmax)

# 3) CPU, without `out`
bout, bargmax = scatter_max(src, index)
print("3: CPU, no out")
print(bout)
print(bargmax)
```
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?
I found this issue as well. Is it resolved?