
Splitting (`uncat`) a tensor over an index (tensor operation)

miestrode opened this issue 2 years ago · 6 comments

Really hope this isn't a duplicate or already implemented, as I still consider myself a newbie.

Feature description

This operation is essentially the inverse of concatenation. The simplest version operates over a single index and inverts a concatenation of two tensors.

Example

Given the 2x3 tensor:

[ 1 2 3 ]
[ 4 5 6 ]

One can split it over the second dimension (dimension 1) at index 1, which results in two new tensors, left and right, with left being:

[ 1 ]
[ 4 ]

And right being:

[ 2 3 ]
[ 5 6 ]

Notice that the split index is exclusive, meaning the element at the index goes to the right tensor. In this case, splitting at index 0 or index 3 would result in an empty tensor alongside the original tensor. We could also generalize this to accepting N indices and returning N + 1 tensors.
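
To make the shape of the API concrete, here is a rough sketch of how such an operation could be built purely on `slice`; the `uncat` name and signature are hypothetical, not an existing Burn API:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical `uncat`: split `tensor` along `dim` at each of the given
/// (sorted) indices, returning N + 1 tensors for N indices.
fn uncat<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    dim: usize,
    indices: &[usize],
) -> Vec<Tensor<B, D>> {
    let dims = tensor.dims();
    let mut parts = Vec::with_capacity(indices.len() + 1);
    let mut start = 0;
    for &end in indices.iter().chain(core::iter::once(&dims[dim])) {
        // Take the full extent of every dimension except `dim`.
        let ranges: [core::ops::Range<usize>; D] =
            core::array::from_fn(|d| if d == dim { start..end } else { 0..dims[d] });
        parts.push(tensor.clone().slice(ranges));
        start = end;
    }
    parts
}
```

On the example above, `uncat(tensor, 1, &[1])` would return exactly the left and right tensors shown.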

Feature motivation

Beyond being useful as an inverse (unless I'm misusing things), I've found that multi-headed networks, where parts of the output tensor need to be interpreted differently, could benefit from this kind of operation. It's possible to do something like this using slicing, but it requires cloning.

miestrode avatar Oct 14 '23 12:10 miestrode

We added a chunk operator (see https://github.com/Tracel-AI/burn/issues/970). Would this solve the issue?
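
For reference, a quick sketch of how it is used (assuming the PyTorch-like `chunk(chunks, dim)` signature and the `TestBackend` alias from the test suite):

```rust
use burn::tensor::{Data, Tensor};

// Split the 2x3 example tensor along dim 1 into 3 single-column pieces.
let tensor: Tensor<TestBackend, 2> =
    Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.]]));
let columns = tensor.chunk(3, 1);
assert_eq!(columns.len(), 3);
// Note: `chunk` produces nearly equal pieces, so it cannot directly
// yield the 1-column / 2-column split from the example.
```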

antimora avatar Nov 29 '23 22:11 antimora

`uncat` is not an op I've seen, though PyTorch does have `split()`, which would also yield the result you want. However, in Burn you can achieve the same outcome with the following:

```rust
use burn::tensor::{Data, Tensor};

// The 2x3 example tensor from the issue description.
let tensor: Tensor<TestBackend, 2> =
    Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.]]));
let [n, c] = tensor.dims();

// `slice` takes ownership, so the first use needs a `clone`.
let chunk1 = tensor.clone().slice([0..n, 0..1]); // [[1.], [4.]]
let chunk2 = tensor.slice([0..n, 1..c]);         // [[2., 3.], [5., 6.]]
```

EDIT: I just read your comment about using slicing and cloning. You don't need to be concerned with cloning. From the Burn book:

> Please take note that tensor operations receive owned tensors as input. For reusing a tensor multiple times, you need to use the clone function. There's no need to worry; this process won't involve actual copying of the tensor data. Instead, it will simply indicate that the tensor is employed in multiple instances, implying that certain operations won't be performed in place. In summary, our API has been designed with owned tensors to optimize performance.
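
For instance, given some float `tensor`, reusing it twice only needs a cheap `clone` (a sketch using the existing `add_scalar` / `mul_scalar` ops):

```rust
// `clone` only registers a second handle; no tensor data is copied.
let doubled = tensor.clone().mul_scalar(2.0);
let shifted = tensor.add_scalar(1.0); // consumes the original handle
```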

dcvz avatar Nov 30 '23 12:11 dcvz

> We added a chunk operator (see #970). Would this solve the issue?

No, I don't believe this would solve things. Chunking is a special case of `uncat`, and the finer-grained control is sometimes needed.

> `uncat` is not an op I've seen, though PyTorch does have `split()`, which would also yield the result you want. However, in Burn you can achieve the same outcome with the following:
>
> ```rust
> let tensor: Tensor<TestBackend, 2> =
>     Tensor::from_data(Data::from([[1., 2., 3.], [4., 5., 6.]]));
> let [n, c] = tensor.dims();
>
> let chunk1 = tensor.clone().slice([0..n, 0..1]);
> let chunk2 = tensor.slice([0..n, 1..c]);
> ```

Yeah, I'm aware you can currently do this in Burn. The solution I've been using is a bit different, but that's beside the point. I think this kind of operation is useful enough to be a standard operation.

I've refrained from contributing in the past because I assumed every operation would be best implemented with a purpose-built kernel, but looking at past commits, if it's okay to implement this operation with existing operations, I might just do it.

miestrode avatar Nov 30 '23 13:11 miestrode

Anyway, I do understand that cloning isn't really costly, but it isn't the prettiest. Even so, I'd say this isn't ideal.

miestrode avatar Nov 30 '23 13:11 miestrode

@miestrode I updated my comment to mention that you don't have to fear cloning in Burn:

> Please take note that tensor operations receive owned tensors as input. For reusing a tensor multiple times, you need to use the clone function. There's no need to worry; this process won't involve actual copying of the tensor data. Instead, it will simply indicate that the tensor is employed in multiple instances, implying that certain operations won't be performed in place. In summary, our API has been designed with owned tensors to optimize performance.

But as mentioned, PyTorch does have a `split` method, which could also be implemented in Burn using other primitives like `slice()` and `reshape()`.
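
For example, a rough sketch of such a PyTorch-style `split` built only on `slice` (the name and signature are hypothetical, mirroring `torch.split` with an integer `split_size`):

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Hypothetical PyTorch-style split: pieces of at most `split_size`
/// along `dim`, the last piece possibly smaller.
fn split<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    split_size: usize,
    dim: usize,
) -> Vec<Tensor<B, D>> {
    let dims = tensor.dims();
    let mut parts = Vec::new();
    let mut start = 0;
    while start < dims[dim] {
        let end = usize::min(start + split_size, dims[dim]);
        // Full range on every dimension except `dim`.
        let ranges: [core::ops::Range<usize>; D] =
            core::array::from_fn(|d| if d == dim { start..end } else { 0..dims[d] });
        parts.push(tensor.clone().slice(ranges));
        start = end;
    }
    parts
}
```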

dcvz avatar Nov 30 '23 13:11 dcvz

`split` seems nice and also has the benefit of being an existing operation. Despite this, I do think a slicing-centric API is nicer, especially since it doesn't require knowing the exact shape of the tensor.

I guess it should be up to the maintainers to decide if something as described here, or split, would be a better fit.

miestrode avatar Nov 30 '23 13:11 miestrode