pytorch_geometric
Support multiple node type sampling in `NeighborLoader`
This PR adds functionality that allows multiple node types to be sampled in `NeighborLoader`.
The interface follows what was discussed in the roadmap (https://github.com/pyg-team/pytorch_geometric/issues/4765):
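```python
import torch
from torch_geometric.loader import NeighborLoader

# Proposed interface: one (node_type, node_indices) pair per node type.
NeighborLoader(
    input_nodes=[
        ('paper', torch.LongTensor([0, 1, 2])),
        ('author', torch.LongTensor([0, 1, 2])),
    ],
    # ...
)
```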
Internally, it converts this into a flat list of `(node_type, index)` tuples: `[('paper', 0), ('paper', 1), ...]`.
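The flattening step can be sketched as follows. This is a minimal illustration, not the code from this PR: `flatten_input_nodes` is a hypothetical helper, and plain Python lists stand in for the `torch.LongTensor` index sets (calling `int()` on each element works for both).

```python
from typing import Iterable, List, Sequence, Tuple


# Hypothetical helper (not from this PR): turns
# [(node_type, indices), ...] into [(node_type, index), ...].
def flatten_input_nodes(
    input_nodes: Sequence[Tuple[str, Iterable]],
) -> List[Tuple[str, int]]:
    flat: List[Tuple[str, int]] = []
    for node_type, indices in input_nodes:
        # int() accepts both Python ints and 0-dim tensor elements.
        flat.extend((node_type, int(i)) for i in indices)
    return flat


seeds = [("paper", [0, 1, 2]), ("author", [0, 1, 2])]
print(flatten_input_nodes(seeds)[:4])
# [('paper', 0), ('paper', 1), ('paper', 2), ('author', 0)]
```

Every seed node becomes its own Python tuple, which keeps batching trivial (any list slice is a valid mini-batch) at the cost of per-element Python objects.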
This is not very efficient, but benchmarks (https://github.com/pyg-team/pytorch_geometric/issues/4765#issuecomment-1147105584) showed the overhead to be acceptable.
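With a flat list of tuples, each mini-batch has to be regrouped by node type before per-type sampling can run. The sketch below shows that regrouping. It is an assumption about how such a step could look, not the PR's implementation, and `group_by_node_type` is a hypothetical name. It illustrates the kind of per-element Python work this representation implies.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


# Hypothetical helper (not from this PR): regroups a mini-batch of
# (node_type, index) tuples into one index list per node type.
def group_by_node_type(
    batch: Sequence[Tuple[str, int]],
) -> Dict[str, List[int]]:
    grouped: Dict[str, List[int]] = defaultdict(list)
    for node_type, index in batch:  # one Python-level step per seed
        grouped[node_type].append(index)
    return dict(grouped)


batch = [("paper", 0), ("author", 2), ("paper", 1)]
print(group_by_node_type(batch))
# {'paper': [0, 1], 'author': [2]}
```

In a real loader, each per-type list would then be converted to a tensor (e.g. with `torch.tensor(indices)`) before being handed to the sampler.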
TODO:
- [x] Add tests
- [x] Add support for `None` instead of providing specific nodes for some node types
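A `None` entry can be read as "use every node of this type". The following is a minimal sketch of resolving those entries before flattening, assuming the loader knows the node count per type. Here that count is the hypothetical `num_nodes` dict; with a PyG `HeteroData` object the same number is available as `data[node_type].num_nodes`. `resolve_input_nodes` is likewise a hypothetical helper, not the PR's code.

```python
from typing import Dict, Iterable, List, Optional, Sequence, Tuple


# Hypothetical helper (not from this PR): replaces a None index set
# with all node indices of that type.
def resolve_input_nodes(
    input_nodes: Sequence[Tuple[str, Optional[Iterable[int]]]],
    num_nodes: Dict[str, int],
) -> List[Tuple[str, List[int]]]:
    resolved = []
    for node_type, indices in input_nodes:
        if indices is None:
            # None means "sample from every node of this type".
            resolved.append((node_type, list(range(num_nodes[node_type]))))
        else:
            resolved.append((node_type, [int(i) for i in indices]))
    return resolved


print(resolve_input_nodes([("paper", [0, 1]), ("author", None)],
                          num_nodes={"paper": 5, "author": 3}))
# [('paper', [0, 1]), ('author', [0, 1, 2])]
```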
Addresses https://github.com/pyg-team/pytorch_geometric/issues/4765
Thanks. FYI, I'll also do this for the link loader and add examples in separate PRs, so if this approach seems okay but not perfect, we can iterate on its complexity later.
Hey @mananshah99, do you also want to take a look at this one? It may need to be merged with https://github.com/pyg-team/pytorch_geometric/pull/5312.
Codecov Report
Merging #5013 (6ce3cf0) into master (79617e0) will decrease coverage by 1.94%. The diff coverage is 86.00%.
:exclamation: Current head 6ce3cf0 differs from pull request most recent head 9af624f. Consider uploading reports for the commit 9af624f to get more accurate results.
```diff
@@            Coverage Diff             @@
##           master    #5013      +/-   ##
==========================================
- Coverage   85.27%   83.32%   -1.95%
==========================================
  Files         338      338
  Lines       18683    18709      +26
==========================================
- Hits        15931    15590     -341
- Misses       2752     3119     +367
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| torch_geometric/data/lightning_datamodule.py | 48.82% <ø> (ø) | |
| torch_geometric/loader/neighbor_loader.py | 92.22% <84.78%> (-2.55%) | :arrow_down: |
| torch_geometric/typing.py | 100.00% <100.00%> (ø) | |
| torch_geometric/nn/models/dimenet_utils.py | 0.00% <0.00%> (-75.52%) | :arrow_down: |
| torch_geometric/nn/models/dimenet.py | 14.51% <0.00%> (-53.00%) | :arrow_down: |
| torch_geometric/profile/profile.py | 37.89% <0.00%> (-26.32%) | :arrow_down: |
| torch_geometric/nn/conv/utils/typing.py | 81.25% <0.00%> (-17.50%) | :arrow_down: |
| torch_geometric/nn/inits.py | 67.85% <0.00%> (-7.15%) | :arrow_down: |
| torch_geometric/transforms/add_self_loops.py | 94.44% <0.00%> (-5.56%) | :arrow_down: |
| torch_geometric/nn/resolver.py | 88.88% <0.00%> (-5.56%) | :arrow_down: |
| ... and 12 more | | |
Thanks for the comments, @mananshah99. Sorry, I've been busy with my day job; I'll review and address them over the weekend.
On second look, I think I'll wait for you to finish your current refactoring PRs. The code has changed a lot, and I'll have to fit this into the new interface. I'll focus on helping review your PRs first and then rework this one.
Thank you for accommodating :) The refactoring PRs are complete now, and the interface is mostly stable. Happy to help move this implementation over to the new interface; it's pretty cool.
Great! I can refactor later this week. I'll probably start a new PR, as I think most of the code needs to move. I'll definitely ask you for a review.