relbench
Help with Key Error when Running Examples from Readme
Following the README instructions, I tried the listed examples:
python gnn_node.py --dataset rel-f1 --task driver-position
python gnn_node.py --dataset rel-avito --task ad-ctr
My environment has fresh installations of the latest relbench (v1.1), PyG, etc.
Both runs fail with the same KeyError, "Tried to collect 'num_sampled_nodes' but did not find any occurrences of it in any node and/or edge type", as shown below.
../lib/python3.10/site-packages/relbench/modeling/utils.py:14: FutureWarning: casting datetime64[ns] values to int64 with .astype(...) is deprecated and will raise in a future version. Use .view(...) instead.
  unix_time = ser.astype("int64").values
0%| | 0/10 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/yye/relbench/examples/gnn_node.py", line 193, in <module>
    train_loss = train()
  File "/home/yye/relbench/examples/gnn_node.py", line 133, in train
    pred = model(
  File ".../python/torch/2/0/dist/lib/python3.10/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yye/relbench/examples/model.py", line 102, in forward
    batch.num_sampled_nodes_dict,
  File ".../python3.10/site-packages/torch_geometric/data/hetero_data.py", line 161, in __getattr__
    return self.collect(key[:-5])
  File ".../python3.10/site-packages/torch_geometric/data/hetero_data.py", line 565, in collect
    raise KeyError(f"Tried to collect '{key}' but did not find any "
KeyError: "Tried to collect 'num_sampled_nodes' but did not find any occurrences of it in any node and/or edge type"
I was unable to reproduce this error with a fresh installation. @weihua916 @rusty1s any ideas what might be going on?
What is your pyg version? Mine is the latest version.
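To make the version question easy to answer, here is a minimal sketch that reports what is actually installed in the active environment. It uses only the standard library; the helper name `installed_versions` and the list of package names are assumptions for illustration, so adjust them to whatever your environment uses.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            result[name] = None
    return result


if __name__ == "__main__":
    # Assumed distribution names; edit this list as needed.
    for name, ver in installed_versions(
        ["relbench", "torch", "torch-geometric", "pyg-lib"]
    ).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Pasting the printed lines into the thread should show whether the environment really picked up the latest releases.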
Have you tried setting subgraph_type="directional" for NeighborLoader? Also, the current baseline GNN accepts num_sampled_nodes_dict and num_sampled_edges_dict as optional arguments but never uses them, so I think you can safely remove them from the call.
def forward(
    self,
    x_dict: Dict[NodeType, Tensor],
    edge_index_dict: Dict[NodeType, Tensor],
    num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
    num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
) -> Dict[NodeType, Tensor]:
    for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)):
        x_dict = conv(x_dict, edge_index_dict)
        x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
        x_dict = {key: x.relu() for key, x in x_dict.items()}
    return x_dict