pytorch_sparse

Problem when passing a SparseTensor to PyG GCNConv

Open yiming421 opened this issue 1 year ago • 0 comments

The problem occurred when I tried to pass a SparseTensor to PyG's GCNConv. I'm working with Python 3.10, CUDA 12.1, torch 2.2.0, PyG 2.5.2 and torch_sparse 0.6.18, all installed with conda on an Ubuntu server. No matter how I create the SparseTensor object, the problem persists. I'm wondering whether it comes from a version compatibility issue or from something wrong in my environment setup (which is very simple, because I only installed torch, PyG and torch_sparse). Has anyone run into a similar problem, or does anyone have an idea why this happens?

I think you can reproduce the issue by running the following code:

    import torch
    from torch_geometric.nn import GCNConv
    from torch_sparse import SparseTensor


    def test():
        # 3 edges on a 5-node graph, stored as a 2 x 3 edge index on the GPU
        ei = torch.tensor([[2, 3, 4], [1, 2, 3]]).cuda(0)
        sp = SparseTensor.from_edge_index(ei, sparse_sizes=(5, 5))
        model = GCNConv(2, 2).cuda(0)
        # 5 nodes with 2 features each
        x = torch.tensor([[1, 1], [1, 1], [1, 1], [1, 1], [1, 1]]).float().cuda(0)
        print(x, sp)
        model(x, sp)  # fails inside gcn_norm -> fill_diag -> set_diag
        print('success')


    test()

Here is the error message:

    Traceback (most recent call last):
      File "/.../debug.py", line 15, in <module>
        test()
      File "/.../debug.py", line 13, in test
        model(x, sp)
      File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
        return forward_call(*args, **kwargs)
      File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 252, in forward
        edge_index = gcn_norm(  # yapf: disable
      File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 64, in gcn_norm
        adj_t = torch_sparse.fill_diag(adj_t, fill_value)
      File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_sparse/diag.py", line 92, in fill_diag
        return set_diag(src, value.new_full(sizes, fill_value), k)
      File "/.../miniconda3/envs/wl/lib/python3.10/site-packages/torch_sparse/diag.py", line 49, in set_diag
        new_row[mask] = row
    RuntimeError: shape mismatch: value tensor of shape [3] cannot be broadcast to indexing result of shape [0]
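Since a version mismatch is one of the suspected causes, a minimal sketch for collecting the installed versions might look like the following. It uses only the standard library's `importlib.metadata`. The helper name `report_versions` is hypothetical, and the distribution names (`torch`, `torch_geometric`, `torch_sparse`, `torch_scatter`) are assumptions about how the packages are registered in this environment; a conda install may register them differently.

```python
from importlib import metadata


def report_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not registered under this name.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; adjust them to match the environment.
    packages = ["torch", "torch_geometric", "torch_sparse", "torch_scatter"]
    for name, version in report_versions(packages).items():
        print(f"{name}: {version or 'not installed'}")
```

Comparing this output with `torch.__version__` and `torch.version.cuda` should help show whether torch_sparse was built against a different torch or CUDA release than the one actually installed.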

yiming421 • Sep 08 '24 08:09