pytorch_geometric
Creating Subgraphs based on edge/node type with HeteroData
🚀 The feature, motivation and pitch
Hi, I'm working on link prediction with a heterograph and am trying to port code from DGL to PyG. @rusty1s
1. Given a `HeteroData` object with several edge types, I want to extract the subgraph made up of one specific edge type together with the nodes it connects. Does PyG have a feature like this? DGL has `dgl.edge_type_subgraph`, and I'd like to replicate that functionality in PyG. For example, in DGL I can call `dgl.edge_type_subgraph(graph, [('author', 'writes', 'paper')])` and get back a graph containing all of the `'writes'` edges. PyG currently has a `subgraph()` function, but it only works on nodes, and you have to explicitly pass the nodes you want to keep. With `HeteroData`, I think it would be useful to have the same functionality for edges, with the option to simply specify which node/edge types to keep.
2. Quick access to features: is there a way to access all edges in a heterograph? DGL has `g.edges()`, which returns a 2-tuple of 1D tensors `(U, V)` representing the source and destination nodes of all edges.
3. In addition to your `from_networkx()` utility, a `from_dgl()` function would also be helpful :)
Alternatives
No response
Additional context
No response
Thanks for the issue.
1. This is a good idea. I think it should be straightforward to add an `edge_list` argument to `subgraph` and filter out edges whose type is not in `edge_list`.
2. We don't have the exact same function. You could get something similar with `HeteroData.edge_stores`, but this gives you the other edge attributes too, which you would need to filter out.
3. This is a nice idea. I'm not familiar enough with DGL to suggest how to go about it, though.
Happy to accept PRs for 1 and 3.
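The `edge_list` filtering idea suggested above can be sketched in plain Python. This is only an illustration under stated assumptions: `filter_edge_types` is a hypothetical helper (not a PyG API), and the graph is modeled as a dict mapping canonical edge types to `(sources, destinations)` lists rather than as a real `HeteroData` object.

```python
from typing import Dict, List, Tuple

# A canonical edge type, e.g. ('author', 'writes', 'paper').
EdgeType = Tuple[str, str, str]
# Per-type COO connectivity as two parallel lists: (sources, destinations).
EdgeIndex = Tuple[List[int], List[int]]


def filter_edge_types(
    edge_index_dict: Dict[EdgeType, EdgeIndex],
    edge_list: List[EdgeType],
) -> Tuple[Dict[EdgeType, EdgeIndex], List[str]]:
    """Keep only the edge types in `edge_list` and report the node types they touch.

    Hypothetical helper illustrating the proposed `edge_list` filter; not part of PyG.
    """
    keep = set(edge_list)
    kept = {et: ei for et, ei in edge_index_dict.items() if et in keep}
    # Collect source/destination node types in first-seen order, without duplicates.
    node_types: List[str] = []
    for src, _, dst in kept:
        for nt in (src, dst):
            if nt not in node_types:
                node_types.append(nt)
    return kept, node_types


graph = {
    ("author", "writes", "paper"): ([0, 1, 2], [0, 0, 1]),
    ("paper", "cites", "paper"): ([0], [1]),
}
sub, node_types = filter_edge_types(graph, [("author", "writes", "paper")])
print(list(sub))   # [('author', 'writes', 'paper')]
print(node_types)  # ['author', 'paper']
```

Node types that no kept edge type touches are dropped, which mirrors what `dgl.edge_type_subgraph` is described as doing in the request.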
1. @wsad1 Do you wanna integrate it? :)
2. We have `data.edge_index_dict` and `data.to_homogeneous().edge_index`, which should be exactly what you want.
3. I think this is a nice idea.
Great, thank you! Please keep me updated :) @wsad1 @rusty1s
Just added support for `HeteroData.node_type_subgraph()` and `HeteroData.edge_type_subgraph()`.
Thanks. Let me know if `from_dgl` ever gets integrated!
Please consider contributing it as well if I don't find time to do it :)