`Batch.to_data_list()` raises after using `Batch.subgraph()` as it doesn't update `Batch.ptr`
🐛 Describe the bug
Description
When calling `Batch.to_data_list()` after using `Batch.subgraph()`, I get a `RuntimeError` because `ptr` (and possibly other internal state) was not properly updated.
Expectation
No error, correct sub-graphing
The subset I'm using is a property of the batch itself that was obtained during batching, so I expected

```python
batch.subgraph(batch["key"])
```

to do the same as

```python
batch.from_data_list([
    data.subgraph(data["key"])
    for data in batch.to_data_list()
])
```
Minimal Reproducible Example
```python
import torch
import torch_geometric.data as geom_data

g1 = geom_data.Data(
    x=torch.tensor(list(range(10))),
    is_pair=torch.tensor([i % 2 == 0 for i in range(10)], dtype=torch.bool),
)
g2 = geom_data.Data(
    x=torch.tensor(list(range(10))) + 20,
    is_pair=torch.tensor([i % 2 == 0 for i in range(10)], dtype=torch.bool),
)

batch = geom_data.Batch.from_data_list([g1, g2])
subbatch = batch.subgraph(batch["is_pair"])

subbatch      # DataBatch(x=[10], is_pair=[10], batch=[10], ptr=[3]) -- notice the batch now contains 10 nodes
subbatch.ptr  # tensor([ 0, 10, 20]), which is wrong: the subbatch contains 10 nodes, not 20, since we sub-sampled
datas = subbatch.to_data_list()  # RuntimeError: start (10) + length (10) exceeds dimension size (10).
```
It is possible other internal data of the batch isn't correct either; I'm not savvy enough with the underlying implementation to check.
Environment
- PyG version: 2.4.0, with a backport of the fix for https://github.com/pyg-team/pytorch_geometric/issues/8272
- PyTorch version: 2.1.0
- OS: Unix
- Python version: 3.8
- CUDA/cuDNN version: 12.1
- How you installed PyTorch and PyG (conda, pip, source): pip
- Any other relevant information (e.g., version of torch-scatter): torch-cluster 1.6.3+pt21cu121, torch_geometric 2.4.0, torch-scatter 2.1.2+pt21cu121, torchaudio 2.1.0, torchmetrics 1.2.0, torchvision 0.16.0
`to_data_list()` can currently only be used on unmodified `Batch` objects coming from `from_data_list()`. There is currently no easy way for us to fix this because `subgraph()` is a method at the `Data` level, not at the `Batch` level.
That makes sense, thanks for the fast answer! I guess I'll have to un-batch and re-batch, then.
Some questions:
- Given `subset` is a `BoolTensor` of size `batch.num_nodes`, are there any tricks that make this possible in this specific case, or a list of internal states of `Batch` I'd need to update to get it working? I'm not too savvy with the internals of `Batch` and what it keeps track of under the hood.
- Would a `subgraph()` method on `Batch`, to override the one from `Data`/`HeteroData`, make sense? From my limited understanding of the codebase:
  - since `subgraph()` selects a subset of nodes, what it is selecting is clear, and this should be doable at the `Batch` level while preserving the internal batch-information state (although not necessarily simple)
  - this would probably require two internal implementations, one for `Data` and the other for `HeteroData`
Some additional thoughts which might be useful (or not).
It seems to me that this is made more complex because Data (and thus Batch) doesn't know the hyper-graph dimensionality of its Tensors, i.e. what their first index relates to: nodes, edges, 3-facet, something else?
- It has knowledge about `x`, which it knows to be at the node level, because this is hard-coded
- Similarly, it has knowledge about `edge_index` and `edge_attr`, which are at the 2-edge level, because this is hard-coded
But in general, given a stored tensor, there is no way to know what the tensor refers to.
A way to store that information would be to extend the `edge_index` and `edge_attr` relationship systematically:
- `indexes` maps an integer `N`, the number of indices in the hyper-edge, to a tensor in COO format with shape `(N, num_N_hyper_edges)` that describes the existing hyper-edges of `N` nodes.
- `attr` or `features` maps an integer `N` as above to a `dict[str, Tensor]` where each Tensor has size `(num_N_hyper_edges, num_N_edge_features)`. The intermediate dict is just a way to give names to some features, e.g. `pos` would become `features[1]["pos"]`.
A graph-level `dict[str, Tensor]` attribute would probably be useful.
It seems to me that if this information was preserved, then Data and Batch could properly handle any Tensor they store during subgraph operation, because they know how to subset them for the selected nodes (and obviously graph-level is preserved as-is), while also updating the internal state for un-batching.
This of course is quite theoretical thinking, and in practice would probably require a new API because it is not really compatible with the current Data and Batch objects.
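To make this concrete, here is a minimal dependency-free sketch of what such an `N`-indexed storage could look like. The layout is hypothetical (not an existing PyG API), and plain Python lists stand in for tensors:

```python
# Hypothetical container illustrating the proposed N-indexed storage.
# `indexes[N]` holds the COO index "tensor" of shape (N, num_N_hyper_edges):
# N=1 is nodes, N=2 is ordinary edges, N=3 would be triangles, etc.
indexes = {
    1: [[0, 1, 2, 3]],            # 4 nodes
    2: [[0, 1, 2], [1, 2, 3]],    # 3 edges: 0->1, 1->2, 2->3
}

# `features[N][name]` holds a (num_N_hyper_edges, num_features) "tensor".
features = {
    1: {"pos": [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]},
    2: {"edge_attr": [[1.0], [2.0], [3.0]]},
}

# With this layout, subsetting is generic: for any N, keep the hyper-edges
# whose N indices all survive the node subset, then slice features[N] the
# same way.
kept = {0, 2, 3}  # example node subset
kept_edges = [
    j for j in range(len(features[2]["edge_attr"]))
    if all(indexes[2][i][j] in kept for i in range(2))
]
# Only edge 2 (2->3) has both endpoints in the subset.
```

Because every stored tensor is reachable through its `N`, the subgraph operation never has to guess what a tensor's first dimension indexes.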
For reference, I used the following code to unbatch, subgraph, then re-batch:

```python
if isinstance(graph, geom_data.Batch) and isinstance(graph, geom_data.Data):
    # workaround for https://github.com/pyg-team/pytorch_geometric/issues/8439
    # We need to unbatch & rebatch, and slice `subset` according to the individual graphs
    num_nodes = graph.num_nodes
    if num_nodes is None:
        raise RuntimeError(
            f"{type(graph).__name__} object has {num_nodes=}, cannot subgraph it"
        )
    if not hasattr(graph, "_slice_dict") or "x" not in graph._slice_dict:
        raise RuntimeError(
            f"Cannot find individual graph node slices in {type(graph).__name__} object"
        )
    else:
        slices: torch.Tensor = graph._slice_dict["x"]
        assert len(slices.shape) == 1
    data_list = graph.to_data_list()
    return graph.from_data_list(
        [
            subgraph(data, subset[start.item() : end.item()])
            for start, end, data in zip(slices[0:-1], slices[1:], data_list)
        ]
    )
```
We provide `Data.is_node_attr` and `Data.is_edge_attr` functionality which will infer whether attributes are node-level or edge-level vectors.
For `Batch.subgraph`: Yes, I think this is possible, but we would require a mask as input here to guarantee that ordering is preserved. We can actually call `super().subgraph()` here and then adjust the slices dictionary to map to the correct new sizes of tensors.
So something like this:
- check that the `subset` input has shape `(self.num_nodes, )` and dtype `torch.bool`, i.e. that it is a boolean mask
- call `super().subgraph()`
- adapt `self._slice_dict` by iterating on attributes, and:
  - for a node attribute, I think `[0, *subset.cumsum(0)[slices[1:] - 1]]` gives the new slices?
  - for an edge attribute I'm not too sure; we would probably need the `edge_mask` returned by `utils.subgraph()` and apply the same formula, but it's not provided by `super().subgraph()`
- I think `self._inc_dict` also needs to be updated? I'm not quite clear on what it stores yet
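The node-attribute formula can be sanity-checked with a small dependency-free sketch; `update_node_slices` is a hypothetical helper (plain Python lists stand in for tensors):

```python
from itertools import accumulate

def update_node_slices(slices, mask):
    """Recompute per-graph slice boundaries after masking nodes.

    `slices` are the old boundaries into the concatenated node dimension
    (e.g. [0, 10, 20] for two graphs of 10 nodes); `mask` is the boolean
    node-subset mask over all batched nodes. The new boundary for graph i
    is the number of kept nodes among the first `slices[i]` nodes, i.e.
    the cumulative sum of the mask evaluated at the old boundaries.
    """
    csum = [0] + list(accumulate(int(m) for m in mask))
    return [csum[s] for s in slices]

# Mirror the original report: two graphs of 10 nodes, keeping even positions.
old_slices = [0, 10, 20]
mask = [i % 2 == 0 for i in range(20)]
new_slices = update_node_slices(old_slices, mask)  # [0, 5, 10]
```

On the reproducer above this yields `[0, 5, 10]`, which is exactly the `ptr` the sub-batch should have had instead of `[0, 10, 20]`.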
I don't know what the best course of action is regarding initial `Data` objects that are completely un-selected and yield an empty slice: should they be dropped and `self._num_graphs` updated? Should they be kept as empty graphs?
Yes, you are right, we would also require the `edge_mask`, so we would need to re-structure `Data.subgraph` into:

```python
def _subgraph(...):
    return data, edge_mask

def subgraph(...):
    return self._subgraph(...)[0]
```

Yes, both `slice` and `inc` need to be updated. `inc` stores the incremental additions of `edge_index` (e.g., the number of nodes before the current graph in the mini-batch).
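Given such an `edge_mask`, the edge-attribute slices could then be updated with the same cumulative-sum trick as the node slices. A dependency-free toy example (made-up sizes, plain Python lists standing in for tensors):

```python
from itertools import accumulate

# Toy batch: two graphs with 3 and 2 edges respectively.
edge_slices = [0, 3, 5]          # old boundaries into the concatenated edge dimension
edge_mask = [True, False, True,  # edges kept in graph 0
             False, True]        # edges kept in graph 1

# Cumulative count of kept edges; evaluating it at the old boundaries
# gives the new per-graph boundaries.
csum = [0] + list(accumulate(int(m) for m in edge_mask))
new_edge_slices = [csum[s] for s in edge_slices]  # [0, 2, 3]
```

The same boundaries, offset-free, would also serve as the new `inc` values for `edge_index` once re-expressed in kept-node counts.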