`append` method for `Batch`
🚀 The feature, motivation and pitch
Add an `append` method to the `Batch` class that takes as arguments the names of the store and attribute, a feature tensor, and a batch assignment tensor. The method should append the tensor to the batch and also update `_slice_dict` and `_inc_dict` so the tensor is split correctly when `to_data_list()` is called on the batch.
Per @rusty1s, this is not a trivial request, as it would require a refactor of the `collate` method. I'm happy to contribute to this refactor – I've worked with PyG as a user for a long time, so I have at least a surface-level understanding of how things work under the hood.
Alternatives
No response
Additional context
I develop a GNN architecture that operates by appending output tensors onto the input graph. When a batch is pushed through the forward pass of the model, we currently have to use a fairly hacky workaround to append the output tensors so that they are split correctly when running analysis on the output:
```python
# append output tensors back onto the input data object
if isinstance(data, Batch):
    dlist = [HeteroData() for _ in range(data.num_graphs)]
    for attr, planes in x.items():
        for p, t in planes.items():
            if t.size(0) == data[p].num_nodes:
                # node-level attribute: split by the per-node batch vector
                tlist = unbatch(t, data[p].batch)
            elif t.size(0) == data.num_graphs:
                # graph-level attribute: one row per graph
                tlist = unbatch(t, torch.arange(data.num_graphs))
            else:
                raise ValueError(f"don't know how to unbatch attribute {attr}")
            for it_d, it_t in zip(dlist, tlist):
                it_d[p][attr] = it_t
    tmp = Batch.from_data_list(dlist)
    data.update(tmp)
    # update() does not carry over the bookkeeping dicts, so copy them manually
    for attr, planes in x.items():
        for p in planes:
            data._slice_dict[p][attr] = tmp._slice_dict[p][attr]
            data._inc_dict[p][attr] = tmp._inc_dict[p][attr]
```
The addition of an append function would mean we no longer need this workaround.
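For context, the `_slice_dict` entries that the workaround copies over are just cumulative element counts per graph. As a minimal sketch using plain `torch` (no PyG internals; `compute_slices` is a hypothetical helper, not part of any API), the boundaries an `append` method would need to record can be derived directly from the batch assignment vector:

```python
import torch

def compute_slices(batch_vec: torch.Tensor, num_graphs: int) -> torch.Tensor:
    # Per-graph element counts, then cumulative boundaries: this matches the
    # layout of the slice tensors that to_data_list() uses to split attributes.
    counts = torch.bincount(batch_vec, minlength=num_graphs)
    return torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])

# nodes 0-1 belong to graph 0, nodes 2-4 to graph 1, node 5 to graph 2
batch_vec = torch.tensor([0, 0, 1, 1, 1, 2])
print(compute_slices(batch_vec, 3))  # tensor([0, 2, 5, 6])
```

Splitting an appended tensor `t` back per graph is then just `t[slices[i]:slices[i + 1]]` for each graph `i`.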
Thanks for creating this issue. I think this would be a pretty cool feature. Two comments:
- Currently, `collate` operates on a per-attribute level (it computes the mini-batch feature by collecting the attribute from all graphs), while here, `collate` would need to operate on a per-example level.
- There exist various ways to implement this: refactor `collate` completely, or provide a separate routine in addition to the current `collate` function to handle this. Looking at the current code, I am probably more in favor of doing this in a separate routine, but we would still need to find a good set of functionalities we can borrow from `collate` so we don't write them twice.
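To make the per-example direction concrete: such a separate routine would essentially concatenate one tensor per example and record the slice boundaries, rather than re-collating every attribute of every graph. A minimal sketch in plain `torch` (the name `collate_per_example` is hypothetical, not PyG API):

```python
import torch

def collate_per_example(parts):
    # parts: one tensor per example (the per-example view described above).
    # Returns the flat mini-batch tensor plus the cumulative slice boundaries
    # needed to split it back into per-example tensors later.
    flat = torch.cat(parts, dim=0)
    sizes = torch.tensor([p.size(0) for p in parts])
    slices = torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)])
    return flat, slices

parts = [torch.randn(2, 4), torch.randn(3, 4), torch.randn(1, 4)]
flat, slices = collate_per_example(parts)
print(flat.shape, slices)  # torch.Size([6, 4]) tensor([0, 2, 5, 6])
```

The round trip `[flat[slices[i]:slices[i + 1]] for i in range(len(parts))]` recovers the original per-example tensors, which is exactly the invariant an `append` routine would have to preserve.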
+1
@rusty1s - are you still interested in this? We (i.e., @younik and the torchgfn team) have an idea of how to build it. Though it could likely be improved, we have a working extension of `GeometricBatch` here:
https://github.com/GFNOrg/torchgfn/blob/e93395a021166c7be1059cf66284c6146a16cc40/src/gfn/utils/graphs.py#L176