pytorch_geometric
How to understand the input of `TransformerConv`?
From now on, we recommend using our discussion forum (https://github.com/rusty1s/pytorch_geometric/discussions) for general questions.
❓ Questions & Help
Hello @rusty1s, thanks for your great implementation of `TransformerConv`.
I see that the input `x` can be either a `Tensor` or a `PairTensor`, but I don't understand what `x` means when it is a `PairTensor`.
Could you explain the meaning of `x` and `edge_index` when `x` is a `PairTensor`?
def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
            edge_attr: OptTensor = None):
    """Run one round of multi-head attention message passing.

    Args:
        x: Node features. A single tensor is reused for both slots of the
            pair. A ``(x[0], x[1])`` pair presumably supplies distinct
            source/target feature sets (bipartite input, PyG's usual
            convention -- confirm against ``MessagePassing``). ``x[1]`` is
            the tensor fed to the skip connection below, so its rows line
            up with the output rows.
        edge_index: Connectivity handed unchanged to ``self.propagate``.
        edge_attr: Optional edge features, forwarded to ``message``.

    Returns:
        Per-node outputs. Shape is ``(-1, heads * out_channels)`` when
        ``self.concat`` is set, otherwise the mean over the heads
        dimension. When ``self.root_weight`` is set, ``lin_skip(x[1])`` is
        added, or gate-blended if ``lin_beta`` exists.
    """
    if isinstance(x, Tensor):
        x: PairTensor = (x, x)

    # propagate_type: (x: PairTensor, edge_attr: OptTensor)
    out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

    # Collapse the heads dimension: average heads, or lay them side by side.
    if not self.concat:
        out = out.mean(dim=1)
    else:
        out = out.view(-1, self.heads * self.out_channels)

    if not self.root_weight:
        return out

    x_root = self.lin_skip(x[1])
    if self.lin_beta is None:
        # Plain residual connection.
        return out + x_root

    # Learned per-row gate deciding how much of the skip path to keep.
    gate = self.lin_beta(torch.cat([out, x_root, out - x_root], dim=-1)).sigmoid()
    return gate * x_root + (1 - gate) * out
def message(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor,
            index: Tensor, ptr: OptTensor,
            size_i: Optional[int]) -> Tensor:
    """Compute attention-weighted messages for every edge.

    Queries come from ``x_i``. Keys and values come from ``x_j``. All three
    are projected and reshaped to ``(-1, heads, out_channels)``.

    If the layer owns an edge projection (``self.lin_edge``), the projected
    edge features are added to BOTH keys and values. When ``lin_edge`` is
    None, ``edge_attr`` is not used, so keys and values always agree on
    whether edge information is included.

    Args:
        x_i: Per-edge features used for queries (presumably the target
            endpoint, PyG ``_i`` convention -- confirm).
        x_j: Per-edge features used for keys and values (presumably the
            source endpoint).
        edge_attr: Edge features. Required when ``self.lin_edge`` exists.
        index, ptr, size_i: Grouping information forwarded to ``softmax``,
            which normalizes the attention logits.

    Returns:
        Messages of shape ``(-1, heads, out_channels)``, i.e. values scaled
        by their (dropout-applied) attention weights.
    """
    query = self.lin_query(x_i).view(-1, self.heads, self.out_channels)
    key = self.lin_key(x_j).view(-1, self.heads, self.out_channels)
    value = self.lin_value(x_j).view(-1, self.heads, self.out_channels)

    # Previously the value path was gated on ``edge_attr is not None`` while
    # the key path was gated on ``lin_edge``. With lin_edge=None and an
    # edge_attr passed, the raw (E, edge_dim) tensor was added to the
    # (E, heads, out_channels) values. Gate both on the projection.
    if self.lin_edge is not None:
        # assert (rather than raise) also lets TorchScript refine Optional.
        assert edge_attr is not None
        edge_emb = self.lin_edge(edge_attr).view(-1, self.heads,
                                                 self.out_channels)
        # Out-of-place adds: avoid mutating views of the Linear outputs.
        key = key + edge_emb
        value = value + edge_emb

    # Scaled dot-product attention logits, one per (edge, head).
    alpha = (query * key).sum(dim=-1) / math.sqrt(self.out_channels)
    alpha = softmax(alpha, index, ptr, size_i)
    alpha = F.dropout(alpha, p=self.dropout, training=self.training)

    return value * alpha.view(-1, self.heads, 1)