scatter_max returns 0s on GPU
Hi,
When I use the scatter_max function on CPU, it works as expected. However, when I move all tensors and models to the GPU, the function stops working: it returns 0 for the max values and index + 1 for the argmax. I am using Windows 11, and so far I have not experienced any other issues with running models on the GPU. Do you know what causes this behavior? Thanks in advance.
I am using the following versions: torch 2.0.1, torch_geometric 2.3.1, torch-scatter 2.1.1, CUDA 11.8.
Can you share some information about how you installed torch-scatter? For which dtypes does this fail?
import torch
from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=...)
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)
I installed torch-scatter with pip, using the following command:
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu118.html
With this example, running it on CPU (with dtype=float) gives me the expected output:
(tensor([[0., 0., 4., 3., 2., 0.],
[2., 4., 3., 0., 0., 0.]], dtype=torch.float64),
tensor([[5, 5, 3, 4, 0, 1],
[1, 4, 3, 5, 5, 5]]))
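For reference, the expected semantics above can be reproduced with a minimal plain-Python sketch (no torch needed; scatter_max_ref is a hypothetical helper, not part of torch-scatter). It also shows where the out-of-range argmax values come from: slots that receive no source element keep the fill value 0 and an argmax equal to the source length, which is exactly what the broken GPU output reports for every slot.

```python
def scatter_max_ref(src_rows, index_rows, fill=0):
    # Output width: global max index + 1, mirroring how torch_scatter
    # infers dim_size when it is not given.
    size = max(i for row in index_rows for i in row) + 1
    out, argmax = [], []
    for src, idx in zip(src_rows, index_rows):
        row_out = [fill] * size
        row_arg = [len(src)] * size  # out-of-range argmax marks empty slots
        for pos, (val, i) in enumerate(zip(src, idx)):
            # Accept the first value for a slot, then keep only larger ones.
            if row_arg[i] == len(src) or val > row_out[i]:
                row_out[i] = val
                row_arg[i] = pos
        out.append(row_out)
        argmax.append(row_arg)
    return out, argmax

src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
out, argmax = scatter_max_ref(src, index)
# out    == [[0, 0, 4, 3, 2, 0], [2, 4, 3, 0, 0, 0]]
# argmax == [[5, 5, 3, 4, 0, 1], [1, 4, 3, 5, 5, 5]]
```

This matches the CPU output above, so the all-zero GPU result with argmax 5 everywhere behaves as if no source element was ever scattered.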
However, the moment I switch to GPU, the following happens:
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=float).cuda()
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]).cuda()
scatter_max(src, index, dim=-1)
Output:
(tensor([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]], device='cuda:0', dtype=torch.float64),
tensor([[5, 5, 5, 5, 5, 5],
[5, 5, 5, 5, 5, 5]], device='cuda:0'))
I also tried the long and double dtypes, but the same behaviour occurs. I also noticed the same behaviour when using the max aggregation from PyTorch Geometric (torch_geometric.nn.aggr.MaxAggregation).
Thanks for confirming. Can you try to uninstall torch-scatter and install via
pip install --verbose torch-scatter
and let me know the installation log and whether this fixes your problems?
Seems to be okay on Manjaro (torch_scatter 2.1.1, torch 2.0.1, Linux 6.5, NVIDIA driver 535.104.05, CUDA 12.2, GTX 1070).
torch-scatter was installed ~20 minutes ago via
pip install torch_scatter --break-system-packages --force-reinstall --no-cache-dir
Output
>>> src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=float).cuda()
>>> index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]).cuda()
>>>
>>> scatter_max(src, index, dim=-1)
(tensor([[0., 0., 4., 3., 2., 0.],
[2., 4., 3., 0., 0., 0.]], device='cuda:0', dtype=torch.float64), tensor([[5, 5, 3, 4, 0, 1],
[1, 4, 3, 5, 5, 5]], device='cuda:0'))
>>>
Installing torch-scatter the way lancercat did seems to have fixed it. I'm not sure what caused the issue.
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?