select_scatter decomp
Fixes #2436. This PR is dependent on https://github.com/pytorch/TensorRT/pull/2519, https://github.com/pytorch/TensorRT/pull/2664 and https://github.com/pytorch/TensorRT/pull/2669. Major changes:
- #2519: decomposition of aten::slice_scatter
- #2664: implementation makes use of aten::scatter.src
- #2669: constants getting converted to fake tensors in the get_attr call, leading to tensors on different devices (meta vs. cpu) in torch
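For context, here is a minimal REPL sketch of how aten::select_scatter can be emulated via aten::scatter.src, the op referenced for #2664; the variable names and steps are illustrative assumptions, not the PR's actual converter code.
>>> import torch
>>> a = torch.zeros(2, 2)
>>> b = torch.ones(2)
>>> src = b.unsqueeze(0)  # restore the dimension removed by select (dim = 0)
>>> idx = torch.zeros_like(src, dtype=torch.long)  # every element targets index 0 along dim 0
>>> torch.scatter(a, 0, idx, src)
tensor([[1., 1.],
        [0., 0.]])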
See this decomposition for an alternative approach.
Thanks @gs-olive for pointing out the above.
But I think the implementation using the slice_scatter decomposition should also work in our case.
For example: in the above case, unsqueezing src would lead to src_tensor having shape [1, 2], while torch.slice_scatter would expect it to be [0, 2]. But since the slice_scatter decomposition does away with the dimension at dim (dim = 0 in this case), the above error would not arise there. Also, with more than two dimensions this case should never be encountered.
So, in this case would the implementation not be functional without the slice_scatter decomposition?
Additionally, if the slice_scatter decomposition changes the behavior of torch.slice_scatter, in the sense that the example here (https://github.com/pytorch/TensorRT/pull/2515#pullrequestreview-1947797291) passes with the decomposition but fails without it, then how does the slice_scatter decomposition change the operator? I thought the decomposition would be 1:1 with the operator, meaning any inputs to the operator are valid inputs to the decomposition and vice versa.
I misread the case you pointed out.
>>> import torch
>>> a = torch.zeros(2, 2)
>>> b = torch.ones(2)
>>> torch.select_scatter(a, b, 0, 0)
tensor([[1., 1.],
[0., 0.]])
>>> torch.slice_scatter(a, b.unsqueeze(0), 0, 1, 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: expected src to have a size equal to the slice of self. src size = [1, 2], slice size = [0, 2]
In the above, according to the implementation, the slice_scatter op would be
torch.slice_scatter(a, b.unsqueeze(0), 0, 0, 1, 1)
which would lead to the slice size being [1, 2].
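For comparison, with the same a and b as above, the corrected call produces the same result as torch.select_scatter:
>>> torch.slice_scatter(a, b.unsqueeze(0), 0, 0, 1, 1)
tensor([[1., 1.],
        [0., 0.]])
>>> torch.select_scatter(a, b, 0, 0)
tensor([[1., 1.],
        [0., 0.]])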
The difference between slice_scatter and select_scatter is that select_scatter inserts a tensor with one fewer dimension, so the src tensor we provide simply omits the dimension at dim.
Example:
For torch.select_scatter:
- for input tensor torch.zeros((2, 2)) with shape [2, 2], the src tensor should be torch.ones(2) with shape [2] (instead of [2, 1]) for dim = 1
- for input tensor torch.zeros((2, 3, 4)) with shape [2, 3, 4], the src tensor should be torch.ones(2, 4) with shape [2, 4] for dim = 1
The op would be torch.select_scatter(input, src, dim, index).
For torch.slice_scatter:
- for input tensor torch.zeros((2, 2)) with shape [2, 2], the src tensor should be torch.ones(2, 1) with shape [2, 1] for dim = 1
- for input tensor torch.zeros((2, 3, 4)) with shape [2, 3, 4], the src tensor should be torch.ones(2, 1, 4) with shape [2, 1, 4] for dim = 1
The op would be torch.slice_scatter(input, src, dim, index, index + 1).
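A quick REPL check of this shape relationship (unsqueezing the select_scatter src at dim turns it into a valid slice_scatter src):
>>> import torch
>>> inp = torch.zeros(2, 3, 4)
>>> src = torch.ones(2, 4)  # select_scatter src for dim = 1: the dim is omitted
>>> torch.equal(torch.select_scatter(inp, src, 1, 0),
...             torch.slice_scatter(inp, src.unsqueeze(1), 1, 0, 1))
True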
To answer the above questions:
- No, the implementation would not be functional without the slice_scatter decomposition.
- The slice_scatter decomposition does not alter the torch.slice_scatter behavior; it is 1:1 with the operator. My earlier comment was because I misunderstood: torch.slice_scatter would also expect the (start, end) bounds to be [0, 1], since the slice_scatter call is torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1) (index to index + 1). But as mentioned in the above comment, since we do away with the unsqueezed dimension in the slice_scatter decomposition, I partially misunderstood the slice_scatter example and thought that it would work (basically the invalid slice_scatter op would not arise in the first place).
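For reference, a minimal sketch of that 1:1 lowering, assuming select_scatter is rewritten in terms of torch.slice_scatter exactly as described above; the function name and body are illustrative, not the PR's actual decomposition.
>>> import torch
>>> def select_scatter_decomp(input_tensor, src_tensor, dim, index):
...     # src_tensor omits the dimension at dim; unsqueeze restores it so its
...     # shape matches the [index, index + 1) slice of input_tensor along dim
...     return torch.slice_scatter(input_tensor, src_tensor.unsqueeze(dim), dim, index, index + 1, 1)
...
>>> torch.equal(select_scatter_decomp(torch.zeros(2, 2), torch.ones(2), 0, 0),
...             torch.select_scatter(torch.zeros(2, 2), torch.ones(2), 0, 0))
True
Since the unsqueezed src always matches the [index, index + 1) slice exactly, the shape mismatch from the earlier example cannot occur once this rewrite is applied.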