
Shape mismatch error when switching the model to GAT

Open Yujun-Yan opened this issue 3 years ago • 3 comments

Hi, I got a shape mismatch error at the line "x_j += edge_attr" in the message function of the GATConv class when I tried to switch to the GAT model. It seems that the reshaping "x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)" in the forward function messes up the shape of "x_j".
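To make the shape problem concrete, here is a tiny standalone illustration (the dimensions are made up; it only mimics the shapes involved, not the repo's actual call path):

```python
import torch

heads, emb_dim = 2, 300
num_nodes, num_edges = 5, 8

# After forward()'s .view(): node features become 3-D.
x = torch.randn(num_nodes, heads * emb_dim).view(-1, heads, emb_dim)          # [5, 2, 300]

# Inside message(), edge_attr is reshaped to the same per-head layout.
edge_attr = torch.randn(num_edges, heads * emb_dim).view(-1, heads, emb_dim)  # [8, 2, 300]

# x_j should be x gathered per edge, i.e. [num_edges, heads, emb_dim]; if the
# 3-D x is lifted along the wrong dimension, x_j keeps a node-sized dimension
# and the addition in message() fails:
x_j = x  # stand-in for a mis-shaped x_j: [5, 2, 300]
try:
    x_j = x_j + edge_attr
except RuntimeError as e:
    print(e)  # shapes [5, 2, 300] and [8, 2, 300] cannot be broadcast together
```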

Yujun-Yan · Jul 06 '21 10:07

You may change `x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)` to `x = self.weight_linear(x)`, and add `x_i = x_i.view(-1, self.heads, self.emb_dim)` and `x_j = x_j.view(-1, self.heads, self.emb_dim)` in `def message(self, edge_index, x_i, x_j, edge_attr):`. The whole code is as follows:

```python
def forward(self, x, edge_index, edge_attr):
    # add self loops in the edge space
    edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

    # add features corresponding to self-loop edges
    self_loop_attr = torch.zeros(x.size(0), 2)
    self_loop_attr[:, 0] = 4  # bond type for self-loop edge
    self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
    edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

    edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

    # x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
    x = self.weight_linear(x)
    # return self.propagate(self.aggr, edge_index[0], x=x, edge_attr=edge_embeddings)
    # add_self_loops returns a tuple here, so edge_index[0] is the edge index tensor
    return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

def message(self, edge_index, x_i, x_j, edge_attr):
    # reshape to per-head views here instead of in forward
    x_i = x_i.view(-1, self.heads, self.emb_dim)
    x_j = x_j.view(-1, self.heads, self.emb_dim)
    edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
    x_j += edge_attr

    alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)

    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, edge_index[0])

    return x_j * alpha.view(-1, self.heads, 1)
```

wubo2180 · Sep 02 '21 07:09

I changed the code as you described, but I still get this error.

```
Traceback (most recent call last):
  File "attribute_masking.py", line 784, in <module>
    main()
  File "attribute_masking.py", line 778, in main
    train_loss, train_acc_atom, train_acc_bond = train(mask_edge, model_list, loader, optimizer_list, device)
  File "attribute_masking.py", line 699, in train
    node_rep = model(batch.x, batch.edge_index, batch.edge_attr)
  File "/home/programs/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "attribute_masking.py", line 276, in forward
    h = self.gnns[layer](h_list[layer], edge_index, edge_attr)
  File "/home/programs/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "attribute_masking.py", line 157, in forward
    return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 344, in propagate
    out = self.aggregate(out, **aggr_kwargs)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 428, in aggregate
    reduce=self.aggr)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_scatter/scatter.py", line 152, in scatter
    return scatter_sum(src, index, dim, out, dim_size)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_scatter/scatter.py", line 11, in scatter_sum
    index = broadcast(index, src, dim)
  File "/home/programs/conda/lib/python3.7/site-packages/torch_scatter/utils.py", line 12, in broadcast
    src = src.expand(other.size())
RuntimeError: The expanded size of the tensor (2) must match the existing size (2918) at non-singleton dimension 1. Target sizes: [2918, 2, 300]. Tensor sizes: [1, 2918, 1]
```
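For reference, the failing expand at the end of the traceback can be reproduced in isolation (shapes copied from the error message; this snippet is only an illustration, not the repo's code):

```python
import torch

# Messages produced by message(): [num_edges, heads, emb_dim].
out = torch.zeros(2918, 2, 300)
# Scatter index: one target node id per edge.
index = torch.zeros(2918, dtype=torch.long)

# torch_scatter broadcasts the index along the aggregation dimension.
# Along dim 1 (i.e. node_dim = -2 on a 3-D tensor) it becomes [1, 2918, 1],
# which cannot expand to [2918, 2, 300]:
# index.view(1, -1, 1).expand(out.size())   # -> the RuntimeError above
# Along dim 0 it becomes [2918, 1, 1], which expands without error:
expanded = index.view(-1, 1, 1).expand(out.size())
print(expanded.shape)  # torch.Size([2918, 2, 300])
```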

kajjana · Apr 20 '22 15:04

@kajjana It seems to be an issue with self.node_dim. Originally it is set to -2, and I hacked around it by setting it to 0. Specifically, there are two ways to handle this:

  1. Rewrite the following function so that it always scatters along dim 0:

         def aggregate(self, inputs: Tensor, index: Tensor,
                       ptr: Optional[Tensor] = None,
                       dim_size: Optional[int] = None) -> Tensor:
             if ptr is not None:
                 ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
                 return segment_csr(inputs, ptr, reduce=self.aggr)
             else:
                 # dim=0 instead of the default dim=self.node_dim
                 return scatter(inputs, index, dim=0, dim_size=dim_size, reduce=self.aggr)

  2. Or, more simply, set self.node_dim = 0 in GATConv (see the sketch below).
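A minimal sketch of option 2, assuming the `GATConv.__init__` signature used in this repo (argument names may differ slightly in your copy):

```python
from torch_geometric.nn import MessagePassing


class GATConv(MessagePassing):
    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):
        # node_dim=0 tells MessagePassing to gather/scatter node features along
        # dim 0, so the [num_edges, heads, emb_dim] messages are aggregated
        # edge -> node over dim 0 instead of the default dim -2.
        # (Assigning self.node_dim = 0 after super().__init__() has the same effect.)
        super(GATConv, self).__init__(node_dim=0, aggr=aggr)
        self.emb_dim = emb_dim
        self.heads = heads
        self.negative_slope = negative_slope
        # ... remaining parameters (weight_linear, att, edge embeddings, bias) unchanged
```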

I have a clean version in this repo, feel free to check it out.

chao1224 · Jul 25 '22 21:07