[`torch.Tensor.__getitem__`] Add advanced indexing (GatherND) support to `torch.Tensor.__getitem__`
The following documentation also appears at the top of the file:
Our conversion of __getitem__ needs to handle basic and advanced indexing (specifically GatherND).
See the NumPy indexing documentation (which PyTorch follows) for more information on the different types of indexing:
https://numpy.org/doc/stable/user/basics.indexing.html
We use the following terms to describe our algorithm:
t, a PyTorch tensor of arbitrary shape and dimensionality on which we are calling __getitem__.
s, a slice index; e.g. the operators :, ..., None, ().
g, a gather index; e.g. (x,...), [x,...], torch.tensor((x,...)) for any arbitrary scalar x.
Note that we currently only handle 1D gather indices, so g is always 1D where described.
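As a concrete illustration of these terms in plain eager-mode PyTorch (a minimal example with arbitrary values, not converter code):

```python
import torch

t = torch.arange(27).reshape(3, 3, 3)  # t: the input tensor

# s: slice indices only -- basic indexing (case 1 below)
basic = t[0:2, :, None]        # the slice, ':' and 'None' are all slice indices

# g: a 1D gather index -- advanced indexing (GatherND)
g = torch.tensor([0, 2])
advanced = t[g, :, g]          # a mix of gather and slice indices (case 2b below)
```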
Our algorithm works as follows:
For an input tensor t, we check the indices argument.
This results in the following cases:
1. If all of the indices are slices, e.g. t[s,s,s,...], this is considered basic indexing,
and we can trivially convert it to TRT using the slice layer (along with some supporting layers).
2. If there are any gather indices, regardless of the presence of slice indices,
e.g. t[...,g,g,g,...], this is considered advanced indexing,
and we are no longer just slicing, but also gathering on the input tensor.
We convert differently depending on the composition of the indices.
2a. If all of the indices are gather indices and there are no slice indices, e.g. t[g,g,g,...],
then we can trivially convert this to TRT using a single gather layer.
2b. If we have a mix of slice and gather indices, e.g. t[s,s,g,g,...], then the TRT conversion gets more complex.
First, we split the indices into slice-only indices and gather-only indices of the same dimensions,
substituting the colon operator for each axis where a gather or slice index was removed from the slice-only
or gather-only indices, respectively. This lets us process the slice and gather indices separately,
since the colon operator ignores an axis when we are not processing that particular type of index.
Consequently, we can now process t as if the indices contained only slice operations, e.g. t[s,s,:,:,...],
using the same basic indexing methodology described in case (1) with a slice layer.
Afterwards, all slicing operations are complete and only gather operations remain.
Using the output of the slice layer, we then process all of the gather indices, e.g. t[:,:,g,g,...].
Because the TRT gather layer does not handle slice indices (i.e. colon operators),
we cannot pass all of the indices to the gather layer as we did in case (2a).
This is especially problematic when a colon operator sits between two gather indices, e.g. t[g,:,g].
To account for these axes, we repeatedly transpose (permute) t until all axes on which we are gathering
are adjacent; in other words, t[g,:,g] == transposed(t)[g,g,:] is a valid equivalence
(we call this coalescing the gather indices for brevity).
This moves any colon-operator dimensions out from between any two gather dimensions
and allows us to use the TRT gather layer to perform the needed GatherND operation,
since only gather indices then remain in the indexing operation; a PyTorch sketch of this
decomposition follows below.
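Here is a minimal eager-mode PyTorch sketch of the case (2b) decomposition for a mixed index like t[s,s,g,g]; permute stands in for the chain of pairwise transposes listed below, and this is an illustration of the idea rather than the converter code itself:

```python
import torch

t = torch.arange(81).reshape(3, 3, 3, 3)
g = torch.tensor([0, 2])

# Step 1: apply only the slice indices, with ':' on the gather axes.
sliced = t[0:2, 1:3, :, :]                # t[s,s,:,:]

# Step 2: coalesce the gather axes (2 and 3) to the front, then gather once.
coalesced = sliced.permute(2, 3, 0, 1)    # gather axes now adjacent at dim 0
gathered = coalesced[g, g]                # a single GatherND-style gather

# The result matches eager-mode advanced indexing (after undoing the permute).
reference = t[0:2, 1:3, g, g]
assert torch.equal(gathered.permute(1, 2, 0), reference)
```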
The following examples, using a 4D tensor of shape (3,3,3,3), show the equivalent transpose operations needed
so that all gather indices can be coalesced when indexing:
t[:,g,:,:] == t.transpose(1,0)[g].transpose(0,1)
t[:,:,g,:] == t.transpose(2,1).transpose(1,0)[g].transpose(0,1).transpose(1,2)
t[:,:,:,g] == t.transpose(3,2).transpose(2,1).transpose(1,0)[g].transpose(0,1).transpose(1,2).transpose(2,3)
t[g,:,g,:] == t.transpose(2,1)[g,g]
t[g,:,:,g] == t.transpose(3,2).transpose(2,1)[g,g]
t[:,g,g,:] == t.transpose(1,0).transpose(2,1)[g,g].transpose(0,1)
t[:,g,:,g] == t.transpose(1,0).transpose(3,2).transpose(2,1)[g,g]
t[:,:,g,g] == t.transpose(2,1).transpose(1,0).transpose(3,2).transpose(2,1)[g,g].transpose(0,1).transpose(1,2)
t[g,g,:,g] == t.transpose(3,2)[g,g,g]
t[g,:,g,g] == t.transpose(2,1).transpose(3,2)[g,g,g]
t[:,g,g,g] == t.transpose(1,0).transpose(2,1).transpose(3,2)[g,g,g].transpose(0,1)
Note the following from the above examples:
- The first gather operation always transposes to dimension 0, if it is not already there.
- Final transposes are needed after the gather operation iff the gather indices were already coalesced (adjacent)
in the original index, because adjacent gather indices keep their result dimensions in place, per the NumPy rules.
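These equivalences can be checked directly in eager-mode PyTorch; for example (a quick sanity check with arbitrary values):

```python
import torch

t = torch.arange(81).reshape(3, 3, 3, 3)
g = torch.tensor([0, 2])

# Single gather index: final transpose restores the original dimension order.
assert torch.equal(t[:, g, :, :],
                   t.transpose(1, 0)[g].transpose(0, 1))

# Non-adjacent gather indices: no final transpose is needed.
assert torch.equal(t[g, :, g, :],
                   t.transpose(2, 1)[g, g])
```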
For posterity, here are some more examples of transposing different combinations of GatherND operations, where tN denotes an N-dimensional tensor and x a 1D gather index; the examples here are what is effectively implemented by the algorithm:
t2[:,x] == t2.transpose(1,0)[x].transpose(0,1)
t3[:,x,:] == t3.transpose(1,0)[x].transpose(0,1)
t3[:,:,x] == t3.transpose(2,1).transpose(1,0)[x].transpose(0,1).transpose(1,2)
t3[x,:,x] == t3.transpose(2,1)[x,x]
t3[:,x,x] == t3.transpose(1,0).transpose(2,1)[x,x].transpose(0,1)
t4[:,x,:,:] == t4.transpose(1,0)[x].transpose(0,1)
t4[:,:,x,:] == t4.transpose(2,1).transpose(1,0)[x].transpose(0,1).transpose(1,2)
t4[:,:,:,x] == t4.transpose(3,2).transpose(2,1).transpose(1,0)[x].transpose(0,1).transpose(1,2).transpose(2,3)
t4[x,:,x,:] == t4.transpose(1,2)[x,x]
t4[x,:,:,x] == t4.transpose(3,2).transpose(2,1)[x,x]
t4[:,x,x,:] == t4.transpose(1,0).transpose(2,1)[x,x].transpose(0,1)
t4[:,x,:,x] == t4.transpose(1,0).transpose(3,2).transpose(2,1)[x,x]
t4[:,:,x,x] == t4.transpose(2,1).transpose(1,0).transpose(3,2).transpose(2,1)[x,x].transpose(0,1).transpose(1,2)
t4[x,x,:,x] == t4.transpose(3,2)[x,x,x]
t4[x,:,x,x] == t4.transpose(2,1).transpose(3,2)[x,x,x]
t4[:,x,x,x] == t4.transpose(1,0).transpose(2,1).transpose(3,2)[x,x,x].transpose(0,1)
t5[x,:,x,:,:] == t5.transpose(2,1)[x,x]
t5[x,:,:,x,:] == t5.transpose(3,2).transpose(2,1)[x,x]
t5[x,:,:,:,x] == t5.transpose(4,3).transpose(3,2).transpose(2,1)[x,x]
t5[:,x,:,x,:] == t5.transpose(1,0).transpose(3,2).transpose(2,1)[x,x]
t5[:,:,x,:,x] == t5.transpose(2,1).transpose(1,0).transpose(4,3).transpose(3,2).transpose(2,1)[x,x]
t5[:,:,:,x,x] == t5.transpose(3,2).transpose(2,1).transpose(1,0).transpose(4,3).transpose(3,2).transpose(2,1)[x,x].transpose(0,1).transpose(1,2).transpose(2,3)
t5[x,x,:,x,:] == t5.transpose(3,2)[x,x,x]
t5[x,x,:,:,x] == t5.transpose(4,3).transpose(3,2)[x,x,x]
t5[x,:,x,x,:] == t5.transpose(2,1).transpose(3,2)[x,x,x]
t5[x,:,x,:,x] == t5.transpose(2,1).transpose(4,3).transpose(3,2)[x,x,x]
t5[x,:,:,x,x] == t5.transpose(3,2).transpose(2,1).transpose(4,3).transpose(3,2)[x,x,x]
t5[x,:,x,x,x] == t5.transpose(2,1).transpose(3,2).transpose(4,3)[x,x,x,x]
t5[x,0,x,x,:] == t5[:,0][x,x,x]
t5[x,0,:,x] == t5[:,0].transpose(2,1)[x,x]
t5[x,None,x] == t5[:,None].transpose(2,1)[x,x]
t5[x,None,:,x] == t5[:,None].transpose(3,2).transpose(2,1)[x,x]
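These can likewise be spot-checked in eager mode, including the None (new axis) case (again just a sanity check with arbitrary values):

```python
import torch

t5 = torch.arange(3 ** 5).reshape(3, 3, 3, 3, 3)
x = torch.tensor([0, 2])

# Gather indices separated by a colon operator.
assert torch.equal(t5[x, :, x, :, :],
                   t5.transpose(2, 1)[x, x])

# Gather indices separated by None (a new axis).
assert torch.equal(t5[x, None, x],
                   t5[:, None].transpose(2, 1)[x, x])
```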
WIP, but advanced indexing is somewhat working. This needs to be based on https://github.com/NVIDIA-AI-IOT/torch2trt/pull/770
@jaybdub This is ready for review, but I may need to rebase... I started this before 0.4.0 was released, and certainly before the major API changes to master.
Note that this is dependent on #770
Also note that I've left in a couple of TODOs... I'll leave those for future PRs
NUM_SUCCESSFUL_CONVERSION: 88
NUM_FAILED_CONVERSION: 0
NUM_ABOVE_TOLERANCE: 0
NUM_pSNR_TOLERANCE: 0
From my local testing, it looks like this implementation now correctly supports dynamic shapes as well.
@jaybdub This PR is ready for review
Updated the algorithm to remove unnecessary gather operations... now only one gather operation is needed, whereas multiple were used before. Besides being more efficient, this also sidesteps a problem we noticed where intermediate operations were failing due to the large shapes produced.