pretrain-gnns
Shape mismatch error when switching the model to GAT
Hi, I got a shape mismatch error on the line `x_j += edge_attr` in the `message` function of the `GATConv` class when I tried to switch to the GAT model. It seems that the reshaping `x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)` in the `forward` function messes up the shape of `x_j`.
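For context, here is a minimal standalone sketch of the kind of mismatch involved (the sizes are made up, not from the repo): once one of the two tensors is laid out per head and the other is flat, the in-place add has nothing to broadcast against.

```python
import torch

# Hypothetical sizes, just to illustrate the shape disagreement.
E, heads, emb_dim = 8, 2, 3
x_j = torch.randn(E, heads * emb_dim)        # per-edge features, flat layout
edge_attr = torch.randn(E, heads, emb_dim)   # edge features, per-head layout

try:
    x_j += edge_attr                         # (8, 6) vs (8, 2, 3): fails
except RuntimeError as e:
    print(e)

# Viewing x_j per head first makes the add well-defined:
out = x_j.view(-1, heads, emb_dim) + edge_attr   # shape (8, 2, 3)
```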
You may change `x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)` to `x = self.weight_linear(x)`, and add `x_i = x_i.view(-1, self.heads, self.emb_dim)` and `x_j = x_j.view(-1, self.heads, self.emb_dim)` at the start of `def message(self, edge_index, x_i, x_j, edge_attr):`. The whole code is as follows:
```python
def forward(self, x, edge_index, edge_attr):
    # add self loops in the edge space
    # (in newer torch_geometric, add_self_loops returns a tuple,
    # which is why edge_index[0] is passed to propagate below)
    edge_index = add_self_loops(edge_index, num_nodes=x.size(0))

    # add features corresponding to self-loop edges
    self_loop_attr = torch.zeros(x.size(0), 2)
    self_loop_attr[:, 0] = 4  # bond type for self-loop edge
    self_loop_attr = self_loop_attr.to(edge_attr.device).to(edge_attr.dtype)
    edge_attr = torch.cat((edge_attr, self_loop_attr), dim=0)

    edge_embeddings = self.edge_embedding1(edge_attr[:, 0]) + self.edge_embedding2(edge_attr[:, 1])

    # x = self.weight_linear(x).view(-1, self.heads, self.emb_dim)
    x = self.weight_linear(x)  # keep x 2-D here; reshape per head in message()

    # return self.propagate(self.aggr, edge_index[0], x=x, edge_attr=edge_embeddings)
    return self.propagate(edge_index[0], x=x, edge_attr=edge_embeddings)

def message(self, edge_index, x_i, x_j, edge_attr):
    x_i = x_i.view(-1, self.heads, self.emb_dim)
    x_j = x_j.view(-1, self.heads, self.emb_dim)
    edge_attr = edge_attr.view(-1, self.heads, self.emb_dim)
    x_j += edge_attr

    alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)
    alpha = F.leaky_relu(alpha, self.negative_slope)
    alpha = softmax(alpha, edge_index[0])

    return x_j * alpha.view(-1, self.heads, 1)
```
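The idea is to keep `x` two-dimensional when it is handed to `propagate`, so that gathering `x_i`/`x_j` happens along the node axis, and to reshape per head only once the per-edge tensors exist inside `message`.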
I changed the code as you described, but I still get this error:
```
Traceback (most recent call last):
  File "attribute_masking.py", line 784, in
```
@kajjana
It seems to be an issue with `self.node_dim`. Originally it is set to -2, and I hacked around it by setting it to 0. Specifically, there are two ways to handle this (see the sketch after the list):
- Rewrite the following function (`scatter` and `segment_csr` come from `torch_scatter`; `Tensor` and `Optional` from `torch` and `typing`):

  ```python
  def aggregate(self, inputs: Tensor, index: Tensor,
                ptr: Optional[Tensor] = None,
                dim_size: Optional[int] = None) -> Tensor:
      if ptr is not None:
          ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
          return segment_csr(inputs, ptr, reduce=self.aggr)
      else:
          # aggregate over dim 0 (the edge axis) instead of self.node_dim
          return scatter(inputs, index, dim=0, dim_size=dim_size, reduce=self.aggr)
  ```
- Or, a simpler way is to fix it by setting `self.node_dim = 0` in `GATConv`.
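A minimal sketch of the second option (the constructor signature below is illustrative, assuming `GATConv` subclasses `MessagePassing` as in this repo):

```python
from torch_geometric.nn import MessagePassing


class GATConv(MessagePassing):
    def __init__(self, emb_dim, heads=2, negative_slope=0.2, aggr="add"):
        super(GATConv, self).__init__()
        self.aggr = aggr
        # ... embeddings, weight_linear, attention parameters as before ...

        # Aggregate messages over dim 0, i.e. the edge axis of the
        # (num_edges, heads, emb_dim) tensor returned by message(),
        # instead of the default node_dim = -2 (the heads axis here).
        self.node_dim = 0
```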
I have a clean version in this repo; feel free to check it out.