[Roadmap] `torch_geometric.nn.aggr` 🚀
🚀 The feature, motivation and pitch
The goal of this roadmap is to unify the concepts of aggregation inside GNNs across both `MessagePassing` and global readouts. Currently, these concepts are separated, e.g., via `MessagePassing.aggr = "mean"` and `global_mean_pool(...)`, even though the underlying implementation is the same. In addition, some aggregations are only available as global pooling operators (`global_sort_pool`, `Set2Set`, ...), while, in theory, they are also applicable during `MessagePassing` (and vice versa, e.g., `SAGEConv.aggr = "lstm"`). A further goal is the combination of multiple aggregations, which is useful both in `MessagePassing` (`PNAConv`, `EGConv`, ...) and in global readouts.
As such, we want to provide reusable aggregations as part of a newly defined `torch_geometric.nn.aggr.*` package. Unifying these concepts also lets us perform optimizations and specialized implementations in a single place (e.g., fused kernels for multiple aggregations). After integration, the following usage becomes possible:
```python
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr="mean")
```

```python
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=LSTMAggr(channels=...))
```

```python
class MyConv(MessagePassing):
    def __init__(self):
        super().__init__(aggr=MultiAggr("mean", "max", Set2Set(channels=...)))
```
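To illustrate the core idea of a single reusable aggregation object, here is a minimal, framework-free sketch in plain Python (no PyTorch). The class names `BaseAggr`, `MeanAggr`, `MaxAggr`, and `MultiAggr` follow the roadmap below, but the signatures are illustrative assumptions, not PyG's actual API:

```python
from collections import defaultdict
from typing import Dict, List, Sequence


class BaseAggr:
    """Abstract aggregation: reduces values grouped by an index vector.

    Mirrors the idea behind torch_geometric.nn.aggr in pure Python
    (names and signatures are illustrative, not PyG's real interface).
    """

    def __call__(self, values: Sequence[float],
                 index: Sequence[int]) -> Dict[int, float]:
        # Group each value by its target index, then reduce per group.
        groups: Dict[int, List[float]] = defaultdict(list)
        for v, i in zip(values, index):
            groups[i].append(v)
        return {i: self.reduce(vs) for i, vs in groups.items()}

    def reduce(self, values: List[float]) -> float:
        raise NotImplementedError


class MeanAggr(BaseAggr):
    def reduce(self, values: List[float]) -> float:
        return sum(values) / len(values)


class MaxAggr(BaseAggr):
    def reduce(self, values: List[float]) -> float:
        return max(values)


class MultiAggr:
    """Runs several aggregations and returns one result per aggregator,
    loosely corresponding to a "concat" combine mode."""

    def __init__(self, aggrs: List[BaseAggr]):
        self.aggrs = aggrs

    def __call__(self, values: Sequence[float],
                 index: Sequence[int]) -> List[Dict[int, float]]:
        return [aggr(values, index) for aggr in self.aggrs]
```

Because the same object works on any `(values, index)` pair, the identical aggregator can back both per-node message aggregation and a graph-level readout, which is exactly the unification this roadmap is after.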
Roadmap
The general roadmap looks as follows (at best, each implemented in a separate PR):
- [x] Define `torch_geometric.nn.aggr.*` and implement a `BaseAggr` abstract class (#4687)
- Add new aggregators:
  - [x] Allow for multiple aggregations as part of a `MultiAggr` class (#4749)
  - [x] Add support for `class-resolver`, similar to here (#4749, #4716)
  - [x] Ensure `torch.jit.script` support (#4779)
  - [x] Integrate new aggregations into `MessagePassing` interface (#4779)
- Move aggregators from `torch_geometric.nn.glob` to `torch_geometric.nn.aggr` (respecting the new interface), deprecate old implementation:
  - [x] `MeanAggr`, `SumAggr`, `MaxAggr`, `MinAggr`, `MulAggr`, `VarAggr`, `StdAggr` (#4687, #4749)
  - [x] `MedianAggr` (#5098)
  - [x] `AttentionalAggr` (#4986)
  - [x] `Set2Set` (#4762)
  - [x] `GlobalSortAggr` (#4957)
  - [x] `GraphMultiSetTransformer` (#4973)
  - [x] `EquilibriumAggr` (#4522)
  - [x] Deprecate `torch_geometric.nn.glob` (#5039)
- Update existing GNN layers to make use of the new interface, e.g.:
  - [x] `SAGEConv` (#4863)
  - [x] `PNAConv` (#4864)
  - [x] `GravNetConv` (#4865)
  - [x] `GENConv` (#4866)
- [x] Add support for "reverse" aggregation resolver to keep `message_and_aggregate` functionality intact (#5084)
- [x] Support for multiple aggregations in `SAGEConv` (#5033)
- [x] `MultiAggregation`: Support for `concat`, `concat+transform`, `sum`, `mean`, `max`, `attention` (#5000, #5034)
- [x] Add `semi_grad` functionality to `SoftmaxAggregation` (#4995)
- [x] Add and verify `torch_geometric.nn.aggr.*` documentation (#5036, #5097, #5099, #5104)
- [x] Add a tutorial on the new concepts (#4927)
- [ ] Kernel fusion: Optimize aggregations, e.g., by computing multiple aggregations in parallel (at best discussed in a separate issue)
Any feedback and help from the community is highly appreciated!
cc: @lightaime @Padarn
Looks great @rusty1s! I'll try to pick up some of the smaller tasks this weekend
Added some simple ones: `MaxAggr`, `MinAggr`, `SumAggr`, `SoftmaxAggr`, `PowermeanAggr`, `VarAggr`, `StdAggr`.
I planned to pick up a couple of the tasks - hope you guys don't mind me editing the issue to make it clear what I plan on doing (I'll stick to smaller PRs if possible since I typically don't have much time during the week).
> Kernel fusion: Optimize aggregations, e.g., by computing multiple aggregations in parallel (at best discussed in a separate issue)
Do we have an open issue on this? I'd be interested to understand a bit more about what we're thinking here. Is it mostly for the case where we want to (for example) compute both a sum and mean?
Yes, indeed. There do not exist clear plans for implementation yet though. It will likely depend on PyTorch fusing these ops as part of TorchScript, or on us providing specialized CUDA kernels.
Thanks everyone for the hard work @lightaime @Padarn. I think that the final outcome looks fantastic - many cool things to promote in our upcoming release :)