geometric-gnn-dojo
Request for clarification on the comments in the aggregate function
First of all, thank you for this wonderful resource. Could you please clarify the following part of the comment in the aggregate function?
inputs: (e, d) - messages `m_ij` from destination to source nodes
If I understand the logic correctly, I would expect the messages to go from the source nodes (j) to the destination nodes (i). Could you elaborate on this point?
Here is the code of the aggregate function:
def aggregate(self, inputs, index):
    """Step (2) Aggregate
    The `aggregate` function aggregates the messages from neighboring nodes,
    according to the chosen aggregation function ('sum' by default).
    Args:
        inputs: (e, d) - messages `m_ij` from destination to source nodes
        index: (e, 1) - list of source nodes for each edge/message in `input`
    Returns:
        aggr_out: (n, d) - aggregated messages `m_i`
    """
    return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)
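To make my reading concrete, here is a minimal pure-Python sketch of how I understand a 'sum' scatter. The `scatter_sum` helper, the toy edge list, and the message values are my own illustrative assumptions, not code from the repository. I am also assuming that edges point from j to i and that `index` holds the node that receives each message.

```python
def scatter_sum(inputs, index, num_nodes):
    """Pure-Python stand-in for a 'sum' scatter (hypothetical helper, not the library call)."""
    dim = len(inputs[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    # Add each edge's message into the row of the node named by `index`.
    for message, node in zip(inputs, index):
        for k, value in enumerate(message):
            out[node][k] += value
    return out


# Toy graph with 3 nodes and directed edges j -> i (assumed convention):
#   0 -> 2, 1 -> 2, 2 -> 0
source = [0, 1, 2]       # j: node that sends the message
destination = [2, 2, 0]  # i: node that receives the message

# One message per edge, shape (e, d) with e = 3 and d = 1.
messages = [[1.0], [10.0], [100.0]]

# Indexing by destination sums, for each node i, the messages sent to it.
aggr_out = scatter_sum(messages, destination, num_nodes=3)
print(aggr_out)  # [[100.0], [0.0], [11.0]]
```

In this sketch node 2 ends up with the sum of the messages sent by nodes 0 and 1, which is why I would expect `index` to list the destination (i) nodes if `aggr_out` is meant to be `m_i`.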
Thank you!