Errors when sampling neighbors with the disjoint setting
🐛 Describe the bug
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.loader import NeighborLoader

if __name__ == '__main__':
    dataset = PygNodePropPredDataset(name='ogbn-products', root='~/GNSL/on_products/dataset')
    split_idx = dataset.get_idx_split()
    data = dataset[0]
    # transform_sampler_output
    # print(data.x, data.x.shape, data.x.dtype)
    loader = NeighborLoader(data, num_neighbors=[10, 10], batch_size=2, disjoint=True, input_nodes=split_idx['train'])
    for mini_batch in loader:
        print(mini_batch)
        break
When I run python3 data.py, the program raises the following exception:
Traceback (most recent call last):
  File "/root/GG4RL/data.py", line 14, in <module>
Environment
- pyg-lib version: '0.3.1+pt20cu118'
- PyTorch version: '2.0.1+cu118'
- OS: Ubuntu 18.04
- Python version: 3.10
- CUDA/cuDNN version: 11.8
- How you installed PyTorch and pyg-lib (conda, pip, source): pip
- Any other relevant information: installed GLIBC-2.9 and pytorch_geometric 2.3.1
I downgraded pyg_lib from 0.4.0 to 0.3.1, as suggested in another issue, after running into a similar error. However, I now encounter this one.
If this issue cannot be resolved right now, could you tell me how to build a data loader in which each mini-batch contains batch_size independent ego-graphs, one per sampled target node, i.e., one that produces a result similar to disjoint=True?
I think one way is to pass in my own transform_sampler_output implementation that does not merge the ego-graphs into one global subgraph. Could you give me a toy example showing how to achieve this? Thanks!
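To make the requested behavior concrete, here is a minimal, library-independent sketch of disjoint neighbor sampling in plain Python. It is not PyG's implementation and does not reproduce its exact node or edge ordering; the helpers sample_disjoint_ego_graphs and split_by_batch are hypothetical names used only for illustration. Each seed gets its own id mapping, so a node reached from two seeds is duplicated rather than shared, and a batch list records which seed each sampled node belongs to.

```python
import random
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def sample_disjoint_ego_graphs(
    adj: Dict[int, List[int]],
    seeds: List[int],
    num_neighbors: List[int],
    rng_seed: int = 0,
) -> Tuple[List[int], List[int], List[Edge]]:
    """Hypothetical helper: sample one independent ego-graph per seed node.

    Returns (node, batch, edges):
      node[i]  -- global id of the i-th sampled node (may repeat across seeds)
      batch[i] -- index of the seed whose ego-graph contains node i
      edges    -- (src, dst) pairs of local indices into `node`, pointing from
                  a sampled neighbor to the node it was sampled for
    """
    rng = random.Random(rng_seed)
    node: List[int] = []
    batch: List[int] = []
    edges: List[Edge] = []
    for b, s in enumerate(seeds):
        # A fresh global-to-local mapping per seed keeps ego-graphs disjoint.
        local: Dict[int, int] = {s: len(node)}
        node.append(s)
        batch.append(b)
        frontier = [s]
        for k in num_neighbors:  # one entry per hop, as in num_neighbors=[10, 10]
            next_frontier: List[int] = []
            for v in frontier:
                nbrs = adj.get(v, [])
                picked = nbrs if len(nbrs) <= k else rng.sample(nbrs, k)
                for u in picked:
                    if u not in local:
                        local[u] = len(node)
                        node.append(u)
                        batch.append(b)
                        next_frontier.append(u)
                    edges.append((local[u], local[v]))
            frontier = next_frontier
    return node, batch, edges


def split_by_batch(
    node: List[int], batch: List[int], edges: List[Edge], num_seeds: int
) -> List[Tuple[List[int], List[Edge]]]:
    """Hypothetical helper: split a disjoint mini-batch into per-seed graphs."""
    graphs: List[Tuple[List[int], List[Edge]]] = [([], []) for _ in range(num_seeds)]
    remap: Dict[int, int] = {}  # mini-batch index -> index inside its ego-graph
    for i, (n, b) in enumerate(zip(node, batch)):
        remap[i] = len(graphs[b][0])
        graphs[b][0].append(n)
    for src, dst in edges:
        # Both endpoints always belong to the same ego-graph.
        graphs[batch[src]][1].append((remap[src], remap[dst]))
    return graphs


if __name__ == "__main__":
    adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
    node, batch, edges = sample_disjoint_ego_graphs(adj, seeds=[0, 3], num_neighbors=[10, 10])
    for nodes, ego_edges in split_by_batch(node, batch, edges, num_seeds=2):
        print(nodes, ego_edges)
```

With a compatible torch-geometric/pyg-lib pair, NeighborLoader(..., disjoint=True) documents a batch vector on each mini-batch that maps nodes to their subgraph; grouping by that vector, as split_by_batch does here, is one possible way to recover per-seed ego-graphs without a custom transform_sampler_output.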
Even though I changed disjoint to False, the same error still occurs. It seems that pyg_lib causes the issue, since torch_sparse works well with disjoint=False.
Hi, I'm having the exact same problem and would also appreciate the help! Here are my package versions:
torch==2.2.0+cu118
torch_cluster==1.6.3+pt22cu118
torch_geometric==2.4.0
torch_scatter==2.1.2+pt22cu118
torchaudio==2.2.0+cu118
torchmetrics==0.11.4
torchvision==0.17.0+cu118
pyg_lib==0.4.0
You either need torch-geometric==2.4.0 and pyg-lib==0.3.0, or torch-geometric==2.5.0 and pyg-lib==0.4.0.
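The pairing above can be checked before running the loader. The sketch below is only an illustration: it knows the two combinations stated in this reply and nothing else, so False means "not one of these two pairs", not "known to be broken". The distribution names torch_geometric and pyg_lib are taken from the pip listing above; the helper names are hypothetical.

```python
from importlib import metadata
from typing import Optional

# Only the two pairings stated in the reply: torch-geometric -> pyg-lib.
KNOWN_GOOD = {
    "2.4.0": "0.3.0",
    "2.5.0": "0.4.0",
}


def base_version(version: str) -> str:
    """Drop a local build suffix such as '+pt20cu118'."""
    return version.split("+", 1)[0]


def is_known_good(pyg_version: str, pyg_lib_version: str) -> bool:
    """Hypothetical check: True only for a pairing listed in KNOWN_GOOD."""
    return KNOWN_GOOD.get(base_version(pyg_version)) == base_version(pyg_lib_version)


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    pyg = installed_version("torch_geometric")
    lib = installed_version("pyg_lib")
    if pyg is None or lib is None:
        print(f"torch_geometric={pyg}, pyg_lib={lib}: at least one is not installed")
    elif is_known_good(pyg, lib):
        print(f"torch_geometric={pyg}, pyg_lib={lib}: matches a listed pairing")
    else:
        print(f"torch_geometric={pyg}, pyg_lib={lib}: not one of the listed pairings")
```

For example, the environment reported just above (torch_geometric 2.4.0 with pyg_lib 0.4.0) is not one of the listed pairings.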