
select_scatter decomp

Open apbose opened this issue 1 year ago • 4 comments

Fixes #2436. This PR is dependent on https://github.com/pytorch/TensorRT/pull/2519, https://github.com/pytorch/TensorRT/pull/2664, and https://github.com/pytorch/TensorRT/pull/2669. Major changes:

  - #2519: decomposition of aten::slice_scatter
  - #2664: implementation making use of aten::scatter.src
  - #2669: constants getting converted to fake tensors in the get_attr call, leading to mismatched device locations (meta vs. cpu) in torch
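As a rough sketch of the decomposition idea discussed here (the function name and exact call shape are illustrative, not the PR's actual code), select_scatter can be rewritten as a slice_scatter over the length-1 slice [index, index + 1), after unsqueezing src to restore the dimension that select_scatter omits:

```python
import torch

def select_scatter_decomp(input_t, src, dim, index):
    # Hypothetical sketch (not the PR's actual code): restore the
    # dimension select_scatter omits, then scatter a length-1 slice.
    src = src.unsqueeze(dim)
    return torch.slice_scatter(input_t, src, dim, index, index + 1)

a = torch.zeros(2, 2)
b = torch.ones(2)
result = select_scatter_decomp(a, b, 0, 0)
expected = torch.select_scatter(a, b, 0, 0)
assert torch.equal(result, expected)  # tensor([[1., 1.], [0., 0.]])
```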

apbose avatar Dec 05 '23 22:12 apbose

See this decomposition for an alternative approach.

gs-olive avatar Mar 20 '24 01:03 gs-olive

Thanks @gs-olive for pointing out the above. But I think the implementation using the slice_scatter decomposition should also work in our case. For example, in the above case, unsqueezing src would lead to src_tensor being [1,2], while torch.slice_scatter would expect it to be [0,2]. But since the slice_scatter decomposition does away with the dimension at dim (dim = 0 in this case), that error would not arise there. Also, in more than two dimensions this case should never be encountered.

apbose avatar Mar 26 '24 20:03 apbose

So, in this case would the implementation not be functional without the slice_scatter decomposition?

Additionally, if the slice_scatter decomposition changes the behavior of torch.slice_scatter, in the sense that the example here (https://github.com/pytorch/TensorRT/pull/2515#pullrequestreview-1947797291) passes with the decomposition but fails without it, then how does the slice_scatter decomposition change the operator? I thought the decomposition would be 1:1 with the operator, meaning any inputs to the operator are valid inputs to the decomposition and vice versa.

gs-olive avatar Mar 27 '24 22:03 gs-olive

I misread the case you pointed out.

>>> import torch
>>> a = torch.zeros(2, 2)
>>> b = torch.ones(2)
>>> torch.select_scatter(a, b, 0, 0)
tensor([[1., 1.],
        [0., 0.]])
>>> torch.slice_scatter(a, b.unsqueeze(0), 0, 1, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: expected src to have a size equal to the slice of self. src size = [1, 2], slice size = [0, 2]

In the above, according to the implementation, the slice_scatter op would be torch.slice_scatter(a, b.unsqueeze(0), 0, 0, 1, 1), which would lead to the slice dimension being [1,2].

The difference between slice_scatter and select_scatter is that select_scatter inserts a single-dimensional tensor, so when providing the src tensor we leave out the dimension at dim. Example for torch.select_scatter:

  - for input tensor torch.zeros(2, 2) with shape [2,2], the src tensor should be torch.ones(2) with shape [2] (instead of [1,2]) for dim = 1
  - for input tensor torch.zeros(2, 3, 4) with shape [2,3,4], the src tensor should be torch.ones(2, 4) with shape [2,4] for dim = 1

The op would be torch.select_scatter(input, src, dim, index).
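For concreteness, a small check of the select_scatter src shapes described above (assuming a PyTorch version that provides torch.select_scatter):

```python
import torch

# src for select_scatter simply drops the dimension being scattered into.
a = torch.zeros(2, 2)
out2d = torch.select_scatter(a, torch.ones(2), 1, 0)        # src shape [2] for dim = 1

x = torch.zeros(2, 3, 4)
out3d = torch.select_scatter(x, torch.ones(2, 4), 1, 0)     # src shape [2, 4] for dim = 1
```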

For torch.slice_scatter:

  - for input tensor torch.zeros(2, 2) with shape [2,2], the src tensor should be torch.ones(2, 1) with shape [2,1] for dim = 1
  - for input tensor torch.zeros(2, 3, 4) with shape [2,3,4], the src tensor should be torch.ones(2, 1, 4) with shape [2,1,4] for dim = 1

The op would be torch.slice_scatter(input, src, dim, index, index + 1).
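For concreteness, a small check of the slice_scatter src shapes described above (assuming a PyTorch version that provides torch.slice_scatter): for a single-index slice [index, index + 1), the scattered dimension is kept, but with size 1.

```python
import torch

a = torch.zeros(2, 2)
s2d = torch.slice_scatter(a, torch.ones(2, 1), 1, 0, 1)        # src shape [2, 1] for dim = 1

x = torch.zeros(2, 3, 4)
s3d = torch.slice_scatter(x, torch.ones(2, 1, 4), 1, 0, 1)     # src shape [2, 1, 4] for dim = 1
```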

To answer the above questions:

  1. No, the slice_scatter decomposition does not alter the torch.slice_scatter behavior. It is 1:1 with the operator.
  2. My earlier comment was because I misunderstood. torch.slice_scatter would also expect the slice range to be [0, 1) (i.e. start = index, end = index + 1), since the decomposition calls torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1). But as mentioned in the above comment, since we do away with the unsqueezed dimension in the slice_scatter decomposition, I partially misunderstood the example of the slice_scatter op and thought it would work (basically, the failing slice_scatter call would not come up in the first place).
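The corrected bounds can be verified directly; a small check (assuming both ops are available in the installed PyTorch):

```python
import torch

a = torch.zeros(2, 2)
b = torch.ones(2)
# With end = index + 1 (here 0 + 1 = 1) rather than end = index, the
# slice has size [1, 2], matching b.unsqueeze(0), and both ops agree:
same = torch.equal(torch.slice_scatter(a, b.unsqueeze(0), 0, 0, 1),
                   torch.select_scatter(a, b, 0, 0))
assert same
```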

apbose avatar Apr 02 '24 03:04 apbose