Shift-Robust-GNNs

Errors when running GAT and GraphSAGE

Open · simonzhang00 opened this issue 2 years ago · 1 comment

Hi,

I get the following errors when running main_gnn.py:

  1. GAT:
Using backend: pytorch
number of classes 7
Using CUDA
Traceback (most recent call last):
  File "main_gnn.py", line 729, in <module>
    micro_f1, macro_f1, out_acc = main(args, [])
  File "main_gnn.py", line 463, in main
    args.aggregator_type
  File ".../Shift-Robust-GNNs/dgl_models.py", line 254, in __init__
    self.layers.append(GATConv(in_feats, n_hidden, num_heads=num_heads, feat_drop=dropout, activation=activation))
  File ".../python3.7/site-packages/dgl/nn/pytorch/conv/gatconv.py", line 160, in __init__
    self._in_src_feats, out_feats * num_heads, bias=False)
  File ".../python3.7/site-packages/torch/nn/modules/linear.py", line 81, in __init__
    self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

2. GraphSAGE:

Using backend: pytorch
number of classes 7
Using CUDA
Traceback (most recent call last):
  File "main_gnn.py", line 729, in <module>
    micro_f1, macro_f1, out_acc = main(args, [])
  File "main_gnn.py", line 544, in main
    total_loss = loss + 1 * cmd(model.h[idx_train, :], model.h[iid_train, :])
  File ".../python3.7/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
    type(self).__name__, name))
AttributeError: 'GraphSAGE' object has no attribute 'h'
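The GraphSAGE traceback fails on the line that adds a distribution-matching penalty, `cmd(model.h[idx_train, :], model.h[iid_train, :])`. That line reads hidden node representations from `model.h`, so the training loop expects each model class to store its hidden features on `self.h` during the forward pass, and the `GraphSAGE` class apparently did not. Below is a pure-Python sketch of both pieces. `ToyModel` is a hypothetical stand-in that caches its hidden output on `self.h`. `cmd` implements central moment discrepancy in its common unscaled form: the mean difference plus the differences of central moments up to order `n_moments`. The name `cmd` matches the call in the traceback, but its body, the default `n_moments=5`, and the omission of the `1/|b-a|^k` range scaling are assumptions. The repository's own implementation may differ, and real code would operate on tensors rather than lists.

```python
import math


class ToyModel:
    """Hypothetical stand-in for a GNN: forward() caches the hidden
    representation on self.h, which the training loop reads as model.h."""

    def __init__(self, hidden_fn, out_fn):
        self.hidden_fn = hidden_fn
        self.out_fn = out_fn
        self.h = None  # filled in by every forward pass

    def forward(self, x):
        self.h = self.hidden_fn(x)  # cache hidden features for the regularizer
        return self.out_fn(self.h)


def _mean(rows):
    # per-dimension mean over rows (nodes)
    return [sum(col) / len(rows) for col in zip(*rows)]


def _central_moment(rows, mean, k):
    # per-dimension k-th central moment: E[(x - E[x])^k]
    return [sum((v - m) ** k for v in col) / len(rows)
            for col, m in zip(zip(*rows), mean)]


def _l2(u, v):
    # Euclidean distance between two equal-length vectors
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def cmd(x1, x2, n_moments=5):
    """Unscaled central moment discrepancy between two samples
    (rows = nodes, columns = hidden dimensions); assumes the hidden
    activations lie in [0, 1], so every 1/|b-a|^k factor is 1."""
    m1, m2 = _mean(x1), _mean(x2)
    total = _l2(m1, m2)
    for k in range(2, n_moments + 1):
        total += _l2(_central_moment(x1, m1, k), _central_moment(x2, m2, k))
    return total
```

Once `model.forward(features)` has run, the penalty in this list-based sketch would be `cmd([model.h[i] for i in idx_train], [model.h[i] for i in iid_train])`. If a model class never assigns `self.h`, reading `model.h` on a PyTorch `nn.Module` raises the `AttributeError` shown above.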

Relevant packages from pip freeze: dgl-cu102==0.6.1, torch==1.9.0+cu102
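The GAT traceback ends inside `nn.Linear`, where `torch.empty` receives the shape tuple `(out_features, in_features)` but rejects it because at least one entry is not an int. That suggests `in_feats` or `n_hidden * num_heads` reached DGL's `GATConv` as something other than an integer, though the root cause is not confirmed in this thread. A hypothetical helper, `check_gat_dims` (not part of the repo or DGL), could run just before the `GATConv(...)` call in `dgl_models.py` to report which argument is wrong. This minimal sketch assumes a homogeneous graph where `in_feats` is a single int. DGL's `GATConv` also accepts an `(src, dst)` pair for bipartite inputs, which this sketch does not handle.

```python
import numbers


def check_gat_dims(in_feats, n_hidden, num_heads):
    """Hypothetical helper: fail early with a readable message if an argument
    that ends up in GATConv's nn.Linear shape is not an integer.

    Returns (in_features, out_features) of that linear layer, i.e.
    (in_feats, n_hidden * num_heads).
    """
    for name, value in (("in_feats", in_feats),
                        ("n_hidden", n_hidden),
                        ("num_heads", num_heads)):
        # bool is a subclass of int, so reject it explicitly
        if isinstance(value, bool) or not isinstance(value, numbers.Integral):
            raise TypeError(
                f"{name} must be an int, got {type(value).__name__}: {value!r}")
    return in_feats, n_hidden * num_heads
```

For example, `check_gat_dims(16, 8, 4)` returns `(16, 32)`, while `check_gat_dims(16.0, 8, 4)` raises `TypeError: in_feats must be an int, got float: 16.0`.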

Thanks.

simonzhang00 · Jun 01 '22 16:06

Hi,

I've fixed the GraphSAGE issue. I'll look into the first one; it seems to be caused by a change in the DGL library. I'll test it and get back to you later.

GentleZhu · Jun 03 '22 15:06