Is there a feature similar to DGL's node_subgraph?
🚀 The feature, motivation and pitch
Background: In PyG, the subgraph function allows for extracting a subgraph based on a subset of nodes or edges. However, during subgraph extraction, the correspondence between the extracted subgraph and the original node features is not explicitly preserved. In contrast, DGL's node_subgraph function retains the original node features for the extracted subgraph, ensuring feature consistency.
Proposed Feature: I propose adding a similar feature to PyG that retains the original node features when constructing a subgraph from a subset of nodes. The function should:
- Extract a subgraph based on a node subset.
- Maintain the correspondence of node features from the original graph.
- Optionally provide mappings between the original and subgraph node indices for reference.
# Given a graph with node features; graph is of type Data
sub_nodes = torch.tensor([0, 2, 4]) # Selected nodes
subgraph = graph.node_subgraph(sub_nodes)
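To make the request concrete, here is a minimal pure-PyTorch sketch of what such a node subgraph would compute: the induced edges, relabeled node IDs, sliced node features, and mappings back to the original graph. `node_subgraph_sketch` is a hypothetical name, not an existing PyG or DGL function, and the sketch assumes `subset` is a 1-D tensor of unique node indices.

```python
import torch


def node_subgraph_sketch(x, edge_index, subset):
    """Hypothetical helper (not a PyG or DGL API) illustrating induced-subgraph extraction.

    Assumes `subset` is a 1-D LongTensor of unique node indices.
    """
    num_nodes = x.size(0)
    device = edge_index.device

    # Boolean mask marking the selected nodes in the original graph.
    node_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    node_mask[subset] = True

    # Induced subgraph: keep only edges whose two endpoints are both selected.
    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
    kept_edges = edge_index[:, edge_mask]

    # Relabel original node IDs to 0..len(subset)-1, following the order of `subset`.
    old_to_new = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
    old_to_new[subset] = torch.arange(subset.numel(), device=device)
    sub_edge_index = old_to_new[kept_edges]

    # Row i of x_sub is row subset[i] of x, so features stay aligned with nodes.
    x_sub = x[subset]

    # Mappings back to the original graph: node i -> subset[i], edge j -> orig_eid[j].
    orig_eid = edge_mask.nonzero().view(-1)
    return x_sub, sub_edge_index, subset, orig_eid


x = torch.arange(10.0).view(5, 2)  # 5 nodes, 2 features each (toy data)
edge_index = torch.tensor([[0, 1, 0, 2, 2, 4, 3, 4],
                           [1, 0, 2, 0, 4, 2, 4, 3]])
sub_nodes = torch.tensor([0, 2, 4])
x_sub, sub_edge_index, orig_nid, orig_eid = node_subgraph_sketch(x, edge_index, sub_nodes)
```

Here `x_sub` holds the rows of nodes 0, 2 and 4, and `sub_edge_index` contains only the edges among those nodes, renumbered as 0, 1, 2.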
Benefits: A streamlined workflow for handling subgraphs while preserving node features, and better compatibility with downstream tasks such as node classification with graph neural networks, where node features play a critical role.
References: DGL's node_subgraph documentation: https://docs.dgl.ai/en/0.9.x/generated/dgl.node_subgraph.html#dgl.node_subgraph
Thank you! This feature would improve the usability of PyG, especially for workflows involving complex feature mappings. Let me know if further clarification is needed.
Alternatives
No response
Additional context
No response
I guess this is what you need^^ https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.subgraph
Thanks for your reply. I have looked into it, but it only returns (edge_index, edge_attr), so the node features seem to be missing.
You can also use Data.subgraph() or HeteroData.subgraph().
Hi, I found this slightly relevant to my problem, so I'm posting my issue here.
Is there a way to extract the graph nodes (as-is, with their original node IDs and features) that are common to the edges a GNN identifies as faulty, after running the graph through the GNN? Say we have enough information to train a supervised edge classifier but no data pointing to faulty nodes, and we need to infer the faulty nodes from the edge-level training. We want to raise a fault at the nodes that occur most often in the faulty routes. The graph has a tree structure. Is there a specific GNN type that addresses this problem, or is there another workaround? Keen to hear your thoughts.
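One simple workaround, sketched below under the assumption that a trained edge classifier already yields a boolean `edge_is_faulty` mask over the original `edge_index`: count how often each node appears as an endpoint of a predicted-faulty edge and flag the highest-scoring nodes. This is a counting heuristic, not a specific GNN architecture, and `rank_suspect_nodes` is a hypothetical helper.

```python
import torch


def rank_suspect_nodes(edge_index, edge_is_faulty, num_nodes, top_k=1):
    """Hypothetical heuristic: score nodes by how many predicted-faulty edges touch them."""
    faulty_edges = edge_index[:, edge_is_faulty]
    # Each faulty edge gives one vote to each of its two endpoints.
    votes = torch.bincount(faulty_edges.reshape(-1), minlength=num_nodes)
    # Nodes with the most votes are the candidates to raise a fault on.
    suspects = torch.topk(votes, k=top_k).indices
    return suspects, votes


# Toy tree (made up): 0 -> 1, 0 -> 2, 1 -> 3, 1 -> 4, 2 -> 5
edge_index = torch.tensor([[0, 0, 1, 1, 2],
                           [1, 2, 3, 4, 5]])
# Assumed output of a trained edge classifier, e.g. edge_logits.sigmoid() > 0.5
edge_is_faulty = torch.tensor([True, False, True, True, False])
suspects, votes = rank_suspect_nodes(edge_index, edge_is_faulty, num_nodes=6)
```

Because the votes are computed on the original `edge_index`, the returned indices are original node IDs, so `x[suspects]` retrieves their original features. If high-degree nodes dominate the ranking, dividing the votes by node degree is one option to consider.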