Is there a feature similar to node_subgraph in DGL?

Open Ziyang-Yu opened this issue 1 year ago • 4 comments

🚀 The feature, motivation and pitch

Background: In PyG, the subgraph function allows for extracting a subgraph based on a subset of nodes or edges. However, during subgraph extraction, the correspondence between the extracted subgraph and the original node features is not explicitly preserved. In contrast, DGL's node_subgraph function retains the original node features for the extracted subgraph, ensuring feature consistency.

Proposed Feature: I propose adding a similar feature to PyG, which retains the original node features while constructing a subgraph from a subset of nodes. The function should:

  1. Extract a subgraph based on a node subset.
  2. Maintain the correspondence of node features from the original graph.
  3. Optionally provide mappings between the original and subgraph node indices for reference.
# Given a graph with node features; graph is of type Data
sub_nodes = torch.tensor([0, 2, 4])  # Selected nodes
subgraph = graph.node_subgraph(sub_nodes)
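
For illustration only, the three requirements listed above can be sketched in plain Python without any graph library. The helper name `induced_subgraph` and its list-based inputs are assumptions made for this sketch; it is not PyG's or DGL's implementation.

```python
def induced_subgraph(edges, features, sub_nodes):
    """Hypothetical helper: node-induced subgraph with features and index mapping.

    edges:     list of (src, dst) pairs using original node ids
    features:  list where features[i] is the feature vector of node i
    sub_nodes: list of original node ids to keep
    """
    # Requirement 3: mapping from original node id to subgraph node id.
    old_to_new = {old: new for new, old in enumerate(sub_nodes)}

    # Requirement 1: keep only edges whose endpoints are both selected,
    # relabelled to the compact subgraph ids.
    sub_edges = [
        (old_to_new[u], old_to_new[v])
        for u, v in edges
        if u in old_to_new and v in old_to_new
    ]

    # Requirement 2: carry over the features of the selected nodes, in order.
    sub_features = [features[old] for old in sub_nodes]
    return sub_edges, sub_features, old_to_new


edges = [(0, 2), (2, 4), (1, 3), (4, 0)]
features = [[0.0], [1.0], [2.0], [3.0], [4.0]]
sub_edges, sub_features, old_to_new = induced_subgraph(edges, features, [0, 2, 4])
```

Here the edge (1, 3) is dropped because neither endpoint is selected, and `old_to_new` records where each original node ended up.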

Benefits: Streamlined workflow for handling subgraphs while preserving node features. Enhanced compatibility with downstream tasks like node classification or graph neural networks where node features play a critical role.

References: DGL's node_subgraph documentation: https://docs.dgl.ai/en/0.9.x/generated/dgl.node_subgraph.html#dgl.node_subgraph

Thank you! This feature would improve the usability of PyG, especially for workflows involving complex feature mappings. Let me know if further clarification is needed.

Alternatives

No response

Additional context

No response

Ziyang-Yu, Nov 21 '24 19:11

I guess this is what you need^^ https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.subgraph

xnuohz, Nov 22 '24 09:11

Thanks for your reply. I have looked into it, but it only returns (edge_index, edge_attr), so the node features appear to be missing.

Ziyang-Yu, Nov 22 '24 20:11

You can also use Data.subgraph() or HeteroData.subgraph().

rusty1s, Dec 09 '24 09:12

Hi, I found this to be slightly relevant to my problem, hence posting my issue here.

Is there a way to extract the graph nodes (as is, with their original node ids and features) that are common among the faulty edges identified by a GNN? Suppose we have enough information to train a supervised edge classifier but no data that labels faulty nodes, so the faulty nodes must be inferred from edge-level training. We want to raise a fault at the nodes that occur most often along the faulty routes. The graph has a tree structure. Is there a specific GNN type that handles this problem, or is there another workaround? Keen to hear your thoughts.

Tan-zee-la, Mar 17 '25 20:03
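
One possible workaround for the question above, sketched under assumptions: it is a post-processing heuristic rather than a GNN-specific method, and it assumes the edge classifier's outputs have already been thresholded into one boolean flag per edge. The helper name `rank_nodes_by_faulty_edges` and the toy data are hypothetical. Because the edges keep their original node ids, the resulting ranking can be used directly to look up the original node features.

```python
from collections import Counter


def rank_nodes_by_faulty_edges(edges, faulty_flags):
    """Hypothetical helper: count how often each node touches a faulty edge.

    edges:        list of (src, dst) pairs using original node ids
    faulty_flags: list of booleans, one per edge (True = predicted faulty)
    Returns (node_id, count) pairs sorted by descending count.
    """
    counts = Counter()
    for (u, v), is_faulty in zip(edges, faulty_flags):
        if is_faulty:
            counts[u] += 1
            counts[v] += 1
    return counts.most_common()


# Toy tree: 0-1, 1-2, 1-3, 3-4, with three edges flagged as faulty.
edges = [(0, 1), (1, 2), (1, 3), (3, 4)]
faulty_flags = [True, True, False, True]

ranking = rank_nodes_by_faulty_edges(edges, faulty_flags)
most_suspect_node, occurrences = ranking[0]
```

In this toy example node 1 is shared by two flagged edges, so it ranks first. Whether raw counts are a sound fault signal depends on the data; in a tree, high-degree nodes touch more edges and may need their counts normalised by degree.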