pytorch_geometric
Weighted Aggregations and Implementation of `SoftMedian`
As discussed in #6867 and #6871, I added a `WeightedAggregation` interface. For now, I provide four aggregations:
- `WeightedMeanAggregation`
- `WeightedQuantileAggregation`
- `WeightedMedianAggregation`
- `SoftMedianAggregation` (from the paper "Robustness of Graph Neural Networks at Scale")
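To make the semantics of the weighted variants concrete, here is a minimal pure-PyTorch sketch of a per-group weighted mean and a dimension-wise weighted quantile (the weighted median being `q=0.5`). The function names `weighted_mean` and `weighted_quantile` are hypothetical and this is not the PR's implementation; the quantile assumes a "lower quantile" convention (first sorted value whose normalized cumulative weight reaches `q`) and loops over groups in Python for readability rather than speed.

```python
import torch


def weighted_mean(x: torch.Tensor, index: torch.Tensor, weight: torch.Tensor,
                  dim_size: int) -> torch.Tensor:
    """Per-group weighted mean: sum_i w_i * x_i / sum_i w_i."""
    # x: [N, F] features, index: [N] group id per row, weight: [N] non-negative.
    num = x.new_zeros(dim_size, x.size(1)).index_add_(0, index, weight.view(-1, 1) * x)
    den = weight.new_zeros(dim_size).index_add_(0, index, weight)
    # Clamping avoids 0/0 for empty groups; they come out as zeros.
    return num / den.clamp(min=1e-12).view(-1, 1)


def weighted_quantile(x: torch.Tensor, index: torch.Tensor, weight: torch.Tensor,
                      dim_size: int, q: float = 0.5) -> torch.Tensor:
    """Dimension-wise weighted q-quantile per group (q=0.5 gives the weighted median)."""
    out = x.new_zeros(dim_size, x.size(1))
    for g in range(dim_size):  # readable, not fast
        mask = index == g
        if not mask.any():
            continue  # empty groups stay zero
        xs, perm = x[mask].sort(dim=0)        # sort each feature independently
        ws = weight[mask][perm]               # [n_g, F]: weights follow their values
        cdf = ws.cumsum(dim=0) / ws.sum(dim=0, keepdim=True)
        # Lower quantile: first sorted position whose cumulative weight reaches q.
        pos = (cdf < q).sum(dim=0, keepdim=True).clamp(max=xs.size(0) - 1)
        out[g] = xs.gather(0, pos).squeeze(0)
    return out
```

With weights `[1, 1, 2]` on values `[1, 2, 3]`, the weighted mean is `9 / 4 = 2.25`, while the weighted median is `2` (the cumulative weight reaches one half exactly at the second value).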
I added an example with the `WeightedMeanAggregation` and `SoftMedianAggregation` aggregations.
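The soft median from the paper replaces the hard selection of a median with a softmax over each point's distance to the dimension-wise median, so points near that median dominate and outliers receive exponentially small weight. Below is a minimal single-neighborhood sketch. It assumes that the per-point weights are folded in by multiplying the softmax scores and renormalizing. The names `soft_median` and `weighted_dim_median`, the `temperature` argument, and the unscaled Euclidean distance are all assumptions; the PR's `SoftMedianAggregation` may scale the distances differently (e.g., by feature dimension).

```python
from typing import Optional

import torch


def weighted_dim_median(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Dimension-wise (lower) weighted median of the rows of x: [n, F] -> [F]."""
    xs, perm = x.sort(dim=0)
    ws = weight[perm]                                    # weights follow their values
    cdf = ws.cumsum(dim=0) / ws.sum(dim=0, keepdim=True)
    pos = (cdf < 0.5).sum(dim=0, keepdim=True).clamp(max=x.size(0) - 1)
    return xs.gather(0, pos).squeeze(0)


def soft_median(x: torch.Tensor, weight: Optional[torch.Tensor] = None,
                temperature: float = 1.0) -> torch.Tensor:
    """Soft median of the rows of x (e.g., the neighbor features of one node)."""
    if weight is None:
        weight = x.new_ones(x.size(0))
    med = weighted_dim_median(x, weight)                 # [F]
    dist = (x - med).norm(dim=-1)                        # [n], distance to the median
    # Rows near the median get large scores; outliers are exponentially damped.
    score = torch.softmax(-dist / temperature, dim=0) * weight
    score = score / score.sum()
    return score @ x                                     # [F]
```

A large `temperature` makes the scores nearly uniform, so the result approaches the (weighted) mean; a small one concentrates the mass on the points closest to the dimension-wise median.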
@rusty1s, is this how you imagined the extra interface? I am not sure I like the distinction. It would be much nicer to have a class hierarchy, e.g., with `Aggregation` as the superclass of `WeightedAggregation`. However, this comes with some issues since the interfaces do not align, particularly when calling `super().__call__`.
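One possible way to let `WeightedAggregation` inherit from `Aggregation` while keeping `super().__call__` usable is to give the base `__call__` an optional `weight` argument. Plain aggregations reject it, and weighted ones fill in uniform weights when none are given. The sketch uses simplified, hypothetical classes (`SumAggregation`, `WeightedSumAggregation`, and the signatures shown). It is not PyG's actual `Aggregation` API.

```python
from typing import Optional

import torch
from torch import Tensor


class Aggregation(torch.nn.Module):
    """Hypothetical, simplified base class: reduces rows of `x` grouped by `index`."""

    def __call__(self, x: Tensor, index: Tensor, dim_size: int,
                 weight: Optional[Tensor] = None) -> Tensor:
        # Shared input validation lives in the base class.
        if index.numel() != x.size(0):
            raise ValueError("`index` needs one entry per row of `x`")
        # torch.nn.Module.__call__ dispatches to forward() (and runs hooks).
        return super().__call__(x, index, dim_size, weight)

    def forward(self, x: Tensor, index: Tensor, dim_size: int,
                weight: Optional[Tensor] = None) -> Tensor:
        raise NotImplementedError


class SumAggregation(Aggregation):
    def forward(self, x, index, dim_size, weight=None):
        if weight is not None:
            raise ValueError(f"{self.__class__.__name__} does not support weights")
        return x.new_zeros(dim_size, x.size(1)).index_add_(0, index, x)


class WeightedAggregation(Aggregation):
    def __call__(self, x, index, dim_size, weight=None):
        if weight is None:  # fall back to uniform weights
            weight = x.new_ones(x.size(0))
        if weight.numel() != x.size(0):
            raise ValueError("`weight` needs one entry per row of `x`")
        # Same signature as the base class, so super().__call__ just works.
        return super().__call__(x, index, dim_size, weight)


class WeightedSumAggregation(WeightedAggregation):
    def forward(self, x, index, dim_size, weight=None):
        out = x.new_zeros(dim_size, x.size(1))
        return out.index_add_(0, index, weight.view(-1, 1) * x)
```

The price of this design is that every aggregation's `forward` carries a `weight` parameter, even when it only exists to be rejected.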
Once we agree on the architecture and implementation details, I will finish the documentation and add some unit tests.
PS: I also added `topk` sparsification for GDC. Of course, I could also move this to a separate pull request if desired.
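For reference, top-k sparsification of a dense GDC diffusion matrix keeps the `k` largest entries per column (or row) and zeroes the rest. Below is a minimal sketch, assuming a dense `[N, N]` matrix. `topk_sparsify` is a hypothetical name, not the signature used in `torch_geometric.transforms.GDC`. Sorting with `stable=True` makes the choice of kept entries deterministic when values tie.

```python
import torch


def topk_sparsify(mat: torch.Tensor, k: int, dim: int = 0) -> torch.Tensor:
    """Keep the k largest entries along `dim` of a dense matrix, zero the rest."""
    k = min(k, mat.size(dim))
    # A stable sort keeps tie-breaking deterministic across runs.
    _, perm = torch.sort(mat, dim=dim, descending=True, stable=True)
    keep = perm.narrow(dim, 0, k)                       # indices of the k largest
    mask = torch.zeros_like(mat).scatter_(dim, keep, 1.0)
    return mat * mask
```

With `dim=0`, each column keeps exactly `k` entries even when several values tie at the cutoff. This is the property that a plain threshold cannot guarantee.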
@rusty1s, I know you are quite busy, but this is a kind reminder that I am still waiting for your feedback on how to implement `WeightedAggregation`. In particular, which class hierarchy would you prefer between `Aggregation` and `WeightedAggregation`?
Yeah, sorry for the delay here. I am on vacation this week and will try to get it in early next week.
Hi @sigeisler, I made some changes and let `WeightedAggregation` inherit from `Aggregation`. Can you add some unit tests for the newly introduced weighted/soft median aggregations? Besides that, the PR seems ready to go.
For now, I removed the integration into `MessagePassing`. I think this is better handled separately so that it can be tested properly.
Codecov Report
Merging #7025 (561483f) into master (c5dca4b) will increase coverage by 0.01%. The diff coverage is 91.15%.
@@ Coverage Diff @@
## master #7025 +/- ##
==========================================
+ Coverage 88.17% 88.18% +0.01%
==========================================
Files 471 472 +1
Lines 28274 28417 +143
==========================================
+ Hits 24930 25060 +130
- Misses 3344 3357 +13
Files | Coverage Δ | |
---|---|---|
torch_geometric/nn/aggr/__init__.py | 100.00% <100.00%> (ø) | |
torch_geometric/nn/aggr/quantile.py | 100.00% <100.00%> (ø) | |
torch_geometric/nn/aggr/base.py | 94.87% <95.83%> (+0.32%) | :arrow_up: |
torch_geometric/nn/aggr/basic.py | 92.52% <83.33%> (-0.62%) | :arrow_down: |
torch_geometric/transforms/gdc.py | 90.41% <92.85%> (+0.22%) | :arrow_up: |
torch_geometric/nn/aggr/soft_median.py | 90.00% <90.00%> (ø) | |
Thank you for taking the time. I am planning to continue after the NeurIPS deadline.
Thank you. Happy to get this in :)
Hi @rusty1s @wsad1 @EdisonLeeeee, I apologize for the massive delay in finishing this pull request.
I have just fixed the last small things. I think this can be reviewed now.
Hi @rusty1s @EdisonLeeeee @wsad1, I know you are super busy with the many things you handle, but I wanted to quickly check on the status.
Will try to get this in this week :) Thanks for the updates (missed the mail).
@EdisonLeeeee, sorry for the delay. Can we do it the other way round? I only added this sparsification because it was missing and is now possible since PyTorch has a stable sort. However, the other feature is the one I am actually interested in getting in here.