
`append` method for `Batch`

vhewes opened this issue 2 years ago • 3 comments

🚀 The feature, motivation and pitch

Add an `append` method to the `Batch` class that takes the store and attribute names, a feature tensor, and a batch tensor; appends the tensor to the batch; and updates `_slice_dict` and `_inc_dict` so the tensor is split correctly when `to_data_list()` is called on the batch.

Per @rusty1s, this is not a trivial request, as it would require a refactor of the `collate` method. I'm happy to contribute to this refactor: I've worked with PyG as a user for a long time, so I have at least a surface-level understanding of how things work under the hood.
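By way of background on the bookkeeping involved: a `Batch` stores each attribute as a single concatenated tensor, and `_slice_dict` records cumulative boundaries so `to_data_list()` knows where to cut. A minimal pure-Python sketch of that slice math, with no PyG dependency (both function names are made up for illustration):

```python
def slices_from_batch(batch):
    """Turn a batch assignment vector like [0, 0, 1, 1, 1, 2] into
    cumulative slice boundaries [0, 2, 5, 6], mirroring the role of
    PyG's internal _slice_dict entries."""
    counts = [0] * (max(batch) + 1)  # elements per graph
    for graph_idx in batch:
        counts[graph_idx] += 1
    slices = [0]
    for c in counts:
        slices.append(slices[-1] + c)
    return slices

def split_by_slices(values, slices):
    """Split a concatenated per-node sequence back into per-graph
    chunks, as to_data_list() does using _slice_dict."""
    return [values[slices[i]:slices[i + 1]] for i in range(len(slices) - 1)]
```

An `append` method would essentially have to compute such boundaries for the new tensor and register them alongside the existing ones.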

Alternatives

No response

Additional context

I develop a GNN architecture that operates by appending output tensors onto the input graph. When a batch is pushed through the model's forward pass, we currently have to rely on a pretty hacky workaround to append the output tensors so that they are split correctly when running analysis on the output:

        # requires: torch, torch_geometric.data.{Batch, HeteroData},
        #           torch_geometric.utils.unbatch
        # append output tensors back onto the input data object
        if isinstance(data, Batch):
            dlist = [HeteroData() for _ in range(data.num_graphs)]
            for attr, planes in x.items():
                for p, t in planes.items():
                    if t.size(0) == data[p].num_nodes:
                        # per-node attribute: split along the node batch vector
                        tlist = unbatch(t, data[p].batch)
                    elif t.size(0) == data.num_graphs:
                        # per-graph attribute: one row per graph
                        tlist = unbatch(t, torch.arange(data.num_graphs))
                    else:
                        raise RuntimeError(
                            f"don't know how to unbatch attribute {attr}")
                    for it_d, it_t in zip(dlist, tlist):
                        it_d[p][attr] = it_t
            # collate the outputs into a fresh batch, then graft its slice
            # and increment metadata back onto the original batch
            tmp = Batch.from_data_list(dlist)
            data.update(tmp)
            for attr, planes in x.items():
                for p in planes:
                    data._slice_dict[p][attr] = tmp._slice_dict[p][attr]
                    data._inc_dict[p][attr] = tmp._inc_dict[p][attr]

The addition of an `append` method would mean we no longer need this workaround.
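The proposed semantics can be sketched on a toy stand-in for `Batch` (pure Python, no PyG dependency; the class and its internals are hypothetical, intended only to show the bookkeeping an `append` would perform):

```python
class ToyBatch:
    """Minimal stand-in for a homogeneous Batch: each attribute is a
    concatenated list plus cumulative slice boundaries."""

    def __init__(self, num_graphs):
        self.num_graphs = num_graphs
        self.store = {}        # attr -> concatenated values
        self.slice_dict = {}   # attr -> cumulative boundaries

    def append(self, attr, values, batch):
        """Attach `values` (sorted by graph) with assignment vector
        `batch`, updating the slice metadata so splitting works."""
        counts = [0] * self.num_graphs
        for graph_idx in batch:
            counts[graph_idx] += 1
        slices = [0]
        for c in counts:
            slices.append(slices[-1] + c)
        self.store[attr] = list(values)
        self.slice_dict[attr] = slices

    def to_data_list(self):
        """Split every attribute back into per-graph dicts."""
        out = [{} for _ in range(self.num_graphs)]
        for attr, values in self.store.items():
            s = self.slice_dict[attr]
            for i in range(self.num_graphs):
                out[i][attr] = values[s[i]:s[i + 1]]
        return out
```

For example, `ToyBatch(2)` with `append('y', [1, 2, 3], [0, 0, 1])` splits back into `[{'y': [1, 2]}, {'y': [3]}]`. A real implementation would additionally need to handle heterogeneous stores and `_inc_dict` offsets for index-valued attributes.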

vhewes avatar Aug 30 '23 16:08 vhewes

Thanks for creating this issue. I think this would be a pretty cool feature. Two comments:

  • Currently, `collate` operates on a per-attribute level (it computes the mini-batch feature by collecting the attribute from all graphs), while here, it would need to operate on a per-example level.
  • There exist various ways to implement this: refactor `collate` completely, or provide a separate routine alongside the current `collate` function. Looking at the current code, I am probably more in favor of a separate routine, but we would still need to find a good set of functionalities to borrow from `collate` so we don't write the same logic twice.

rusty1s avatar Sep 01 '23 08:09 rusty1s

+1

younik avatar Apr 18 '25 13:04 younik

@rusty1s, are you still interested in this? We (i.e., @younik and the torchgfn team) have an idea of how to build it. Though it could likely be improved, we have a working extension of `GeometricBatch` here:

https://github.com/GFNOrg/torchgfn/blob/e93395a021166c7be1059cf66284c6146a16cc40/src/gfn/utils/graphs.py#L176

josephdviviano avatar Sep 29 '25 17:09 josephdviviano