
`Batch.to_data_list()` raises after using `Batch.subgraph()` as it doesn't update `Batch.ptr`

Open QuentinSoubeyranAqemia opened this issue 2 years ago • 6 comments

🐛 Describe the bug

Description

When calling Batch.to_data_list() after using Batch.subgraph(), I get a RuntimeError because ptr (and possibly other internal state) is not properly updated.

Expectation

No error, correct sub-graphing

The subset I'm using is a property of the batch itself that was obtained during batching, so I expected

batch.subgraph(batch["key"])

to do the same as

batch.from_data_list([
    data.subgraph(data["key"])
    for data in batch.to_data_list()
])

Minimal Reproducible Example

import torch
import torch_geometric.data as geom_data

g1 = geom_data.Data(
    x=torch.tensor(list(range(10))),
    is_pair=torch.tensor([i % 2 == 0 for i in range(10)], dtype=torch.bool),
)
g2 = geom_data.Data(
    x=torch.tensor(list(range(10))) + 20,
    is_pair=torch.tensor([i % 2 == 0 for i in range(10)], dtype=torch.bool),
)
batch = geom_data.Batch.from_data_list([g1, g2])

subbatch = batch.subgraph(batch["is_pair"])
subbatch  # DataBatch(x=[10], is_pair=[10], batch=[10], ptr=[3])  # notice the sub-batch now contains 10 nodes
subbatch.ptr  # tensor([ 0, 10, 20]), which is wrong: the sub-batch contains 10 nodes, not 20, since we sub-sampled

datas = subbatch.to_data_list()  # RuntimeError: start (10) + length (10) exceeds dimension size (10).

It's possible that other internal data of the batch is incorrect as well; I'm not savvy enough with the underlying implementation to test that.

Environment

  • PyG version: 2.4.0, with a backport of the fix for https://github.com/pyg-team/pytorch_geometric/issues/8272
  • PyTorch version: 2.1.0
  • OS: Unix
  • Python version: 3.8
  • CUDA/cuDNN version: 12.1
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter):
    torch-cluster            1.6.3+pt21cu121
    torch_geometric          2.4.0
    torch-scatter            2.1.2+pt21cu121
    torchaudio               2.1.0
    torchmetrics             1.2.0
    torchvision              0.16.0
    

QuentinSoubeyranAqemia avatar Nov 24 '23 16:11 QuentinSoubeyranAqemia

to_data_list() can currently only be used on unmodified Batch objects coming from from_data_list(). There is currently no easy way for us to fix this because subgraph() is a method at the Data level, not at the Batch level.

rusty1s avatar Nov 26 '23 10:11 rusty1s

That makes sense, thanks for the fast answer! I guess I'll have to un-batch and re-batch then.

Some questions:

  • Given that subset is a BoolTensor of size batch.num_nodes, are there any tricks that make this possible in this specific case, or a list of internal states of Batch I'd need to update to get it working? I'm not too savvy with the internals of Batch and what it keeps track of under the hood
  • Would a subgraph() method on Batch, overriding the one from Data/HeteroData, make sense? From my limited understanding of the codebase:
    • since subgraph() selects a subset of nodes, what it selects is well-defined, and this should be doable at the Batch level while preserving the internal batching state (although not necessarily simple)
    • this would probably require two internal implementations, one for Data and the other for HeteroData

Some additional thoughts which might be useful (or not).

It seems to me that this is made more complex because Data (and thus Batch) doesn't know the hyper-graph dimensionality of its tensors, i.e. what their first index refers to: nodes, edges, 3-facets, something else?

  • It has knowledge about x, which it knows to be at the node level, because this is hard-coded
  • Similarly, it has knowledge about edge_index and edge_attr, which are at the 2-edge level, because this is hard-coded

But in general, given a stored tensor, it cannot tell what the tensor refers to.

A way to store that information would be to extend the edge_index and edge_attr relationship systematically:

  • indexes maps an integer N, the number of indices in the hyper-edge, to a tensor in COO format with shape (N, num_N_hyper_edges) that describes the existing hyper-edges over N nodes.
  • attr or features maps an integer N as above to a dict[str, Tensor] where each Tensor has shape (num_N_hyper_edges, num_N_edge_features). The intermediate dict is just a way to give names to features, e.g. pos would become features[1]["pos"]

A graph-level dict[str, Tensor] attribute would probably be useful.

It seems to me that if this information were preserved, then Data and Batch could properly handle any tensor they store during a subgraph operation, because they would know how to subset it for the selected nodes (graph-level attributes are obviously preserved as-is), while also updating the internal state needed for un-batching.

This of course is quite theoretical thinking, and in practice would probably require a new API because it is not really compatible with the current Data and Batch objects.
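As a purely illustrative sketch of that idea (every name here is hypothetical, and none of this is PyG API), such a container could look like:

```python
from dataclasses import dataclass, field
from typing import Dict

# Hypothetical container for the proposal above (not PyG API). Hyper-edge
# indices and features are keyed by their arity N, so the container always
# knows which axis of a stored tensor indexes nodes (N=1), edges (N=2),
# 3-facets (N=3), and so on.
@dataclass
class HyperGraphData:
    # N -> COO index tensor of shape (N, num_N_hyper_edges)
    indexes: Dict[int, "Tensor"] = field(default_factory=dict)
    # N -> named feature tensors, each of shape (num_N_hyper_edges, num_features)
    features: Dict[int, Dict[str, "Tensor"]] = field(default_factory=dict)
    # graph-level attributes, kept as-is by any subgraph operation
    globals: Dict[str, "Tensor"] = field(default_factory=dict)

# e.g. node positions would live at features[1]["pos"]:
g = HyperGraphData()
g.features[1] = {"pos": [[0.0, 0.0], [1.0, 0.0]]}
print(g.features[1]["pos"])
```

With this layout, a subgraph operation could in principle subset every stored tensor correctly, since each one's node/edge/facet level is explicit.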

QuentinSoubeyranAqemia avatar Nov 27 '23 09:11 QuentinSoubeyranAqemia

For reference, I used the following code to unbatch, subgraph then re-batch:

if isinstance(graph, geom_data.Batch) and isinstance(graph, geom_data.Data):
    # workaround for https://github.com/pyg-team/pytorch_geometric/issues/8439
    # We need to unbatch & rebatch, and slice `subset` according to the individual graphs
    num_nodes = graph.num_nodes
    if num_nodes is None:
        raise RuntimeError(
            f"{type(graph).__name__} object has {num_nodes=}, cannot subgraph it"
        )
    if not hasattr(graph, "_slice_dict") or "x" not in graph._slice_dict:
        raise RuntimeError(
            f"Cannot find individual graph node slices in {type(graph).__name__} object"
        )
    else:
        slices: torch.Tensor = graph._slice_dict["x"]
    assert len(slices.shape) == 1
    data_list = graph.to_data_list()

    return graph.from_data_list(
        [
            subgraph(data, subset[start.item() : end.item()])  # type: ignore
            for start, end, data in zip(slices[0:-1], slices[1:], data_list)
        ]  # type: ignore
    )  # type: ignore

QuentinSoubeyranAqemia avatar Nov 27 '23 15:11 QuentinSoubeyranAqemia

We provide Data.is_node_attr and Data.is_edge_attr functionality which will infer whether attributes are node-level or edge-level vectors.

For Batch.subgraph: Yes, I think this is possible, but we would require a mask as input here to guarantee that ordering is preserved. We can actually call super().subgraph() here and then adjust the slices dictionary to map to the correct new sizes of tensors.

rusty1s avatar Nov 28 '23 14:11 rusty1s

So something like this:

  • check that the subset input has shape (self.num_nodes,) and dtype torch.bool, i.e. that it is a boolean node mask
  • call super().subgraph()
  • adapt self._slice_dict by:
    • iterating over the attributes, and:
      • for a node attribute, I think [0, *subset.cumsum(0)[slices[1:] - 1]] gives the new slices?
      • for an edge attribute I'm not too sure; we would probably need the edge_mask returned by utils.subgraph() and apply the same formula, but it's not provided by super().subgraph()
  • I think self._inc_dict also needs to be updated? I'm not quite clear on what it stores yet
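For what it's worth, the node-attribute slice update can be sanity-checked against the numbers from the reproducible example above. This is only a sketch: update_node_slices is a made-up helper name, not PyG API.

```python
import torch

# Sketch of the node-attribute slice update discussed above. Given the old
# per-graph node slices and a boolean node mask, recompute the slices of the
# sub-sampled batch by counting selected nodes up to each graph boundary.
def update_node_slices(slices: torch.Tensor, subset: torch.Tensor) -> torch.Tensor:
    counts = subset.long().cumsum(0)   # selected nodes up to each node index
    new_tail = counts[slices[1:] - 1]  # selected nodes up to each graph boundary
    return torch.cat([torch.zeros(1, dtype=torch.long), new_tail])

# Two graphs of 10 nodes each, keeping every other node, as in the example above:
slices = torch.tensor([0, 10, 20])
subset = torch.tensor([i % 2 == 0 for i in range(20)])
print(update_node_slices(slices, subset))  # tensor([ 0,  5, 10])
```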

I don't know what the best course of action is regarding input Data that are completely un-selected and yield an empty slice: should they be dropped and self._num_graphs be updated? Should they be kept as empty graphs?

QuentinSoubeyranAqemia avatar Nov 28 '23 15:11 QuentinSoubeyranAqemia

Yes, you are right, we would also require the edge_mask, so we would need to re-structure Data.subgraph into:

def _subgraph(...):
    return data, edge_mask

def subgraph(...):
    return self._subgraph(...)[0]

Yes, both slice and inc need to be updated. inc stores the incremental additions for edge_index (e.g., the number of nodes before the current graph in the mini-batch).
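To illustrate that inc idea with concrete numbers (plain torch, not PyG internals; the tensor names are made up):

```python
import torch

# Two graphs with 3 nodes each: when concatenated into a batch, each graph's
# edge_index must be shifted by the number of nodes preceding it (its "inc").
e1 = torch.tensor([[0, 1], [1, 2]])  # edges of graph 1
e2 = torch.tensor([[0, 2], [2, 1]])  # edges of graph 2
inc = torch.tensor([0, 3])           # nodes before graph 1 and graph 2
batched = torch.cat([e1 + inc[0], e2 + inc[1]], dim=1)
print(batched)  # tensor([[0, 1, 3, 5], [1, 2, 5, 4]])
```

Un-batching reverses this by subtracting inc again, which is why a Batch-level subgraph() would also have to recompute these offsets.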

rusty1s avatar Nov 29 '23 09:11 rusty1s