pytorch_scatter
[Feature] Support for a scatter 'concatenate' or 'groupby' operation
Hi, thanks for the amazing work so far!
I was wondering if it would be possible to efficiently support a scatter operation that, instead of reducing the values (e.g. using sum, mean, max, or min), simply returns the input values grouped by their index.
For example, following the homepage illustration of this repo:
index = [0, 0, 1, 0, 2, 2, 3, 3]
input = [5, 1, 7, 2, 3, 2, 1, 3]
I would like to get an output similar to this:
0: [5, 1, 2]
1: [7]
2: [3, 2]
3: [1, 3]
(the order within each list would not matter)
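For reference, the grouping itself can be obtained with plain PyTorch operations by stably sorting the index and splitting the permuted input by group counts. Below is a minimal sketch, assuming a 1-D `index` of non-negative `int64` values and a 1-D `src` of the same length. `scatter_groupby` is a hypothetical helper name, not an existing torch_scatter function.

```python
from typing import List, Optional

import torch


def scatter_groupby(src: torch.Tensor, index: torch.Tensor,
                    dim_size: Optional[int] = None) -> List[torch.Tensor]:
    """Hypothetical helper (not part of torch_scatter): group `src` by `index`."""
    if dim_size is None:
        dim_size = int(index.max()) + 1
    # A stable sort keeps the original order of elements inside each group.
    _, perm = torch.sort(index, stable=True)
    # Number of elements that fall into each group (0 for empty groups).
    counts = torch.bincount(index, minlength=dim_size)
    # torch.split needs Python ints, so .tolist() forces a GPU -> CPU sync here.
    return list(torch.split(src[perm], counts.tolist()))


index = torch.tensor([0, 0, 1, 0, 2, 2, 3, 3])
src = torch.tensor([5, 1, 7, 2, 3, 2, 1, 3])
groups = scatter_groupby(src, index)
# groups -> [tensor([5, 1, 2]), tensor([7]), tensor([3, 2]), tensor([1, 3])]
```

The returned tensors are views into one permuted copy of `src`, so the main per-call overhead is the sort plus the single host sync needed to build the split sizes.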
I am not sure if I am missing something or if this is already possible using existing operations. The varying group sizes may be problematic, but this could be handled with nested tensors or padding. I would like to apply this operation several times per training epoch, so ideally it would be efficient on GPUs.
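For the padding variant, a possible sketch (again plain PyTorch, with `scatter_pad` as a hypothetical name, assuming a 1-D non-negative `int64` `index`) computes each element's position inside its group from the sorted order and writes into a dense `[num_groups, max_group_size]` tensor plus a boolean mask. Everything stays on the device except reading `max_count`, which is needed to allocate the output.

```python
from typing import Optional, Tuple

import torch


def scatter_pad(src: torch.Tensor, index: torch.Tensor,
                dim_size: Optional[int] = None,
                fill_value: float = 0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: pad groups of `src` (grouped by `index`) into rows."""
    if dim_size is None:
        dim_size = int(index.max()) + 1
    # Stable sort so elements keep their original relative order per group.
    sorted_index, perm = torch.sort(index, stable=True)
    counts = torch.bincount(index, minlength=dim_size)
    # Offset of the first element of each group within the sorted order.
    starts = torch.cumsum(counts, 0) - counts
    # Position of every sorted element inside its own group (0, 1, 2, ...).
    pos = torch.arange(index.numel(), device=index.device) - starts[sorted_index]
    max_count = int(counts.max())  # one host sync to size the output
    out = src.new_full((dim_size, max_count), fill_value)
    mask = torch.zeros((dim_size, max_count), dtype=torch.bool, device=src.device)
    out[sorted_index, pos] = src[perm]
    mask[sorted_index, pos] = True
    return out, mask


out, mask = scatter_pad(torch.tensor([5, 1, 7, 2, 3, 2, 1, 3]),
                        torch.tensor([0, 0, 1, 0, 2, 2, 3, 3]))
# out  -> [[5, 1, 2], [7, 0, 0], [3, 2, 0], [1, 3, 0]]
# mask marks which entries are real values rather than padding.
```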