
[RFC] Shape handling for output lists of tensors

qedawkins opened this issue 3 years ago • 8 comments

For split-like ops, the output type is a list of tensors (potentially of different shapes), which the current shape refinement pass cannot handle. The solution introduced in #820 was an early decomposition pass that runs before shape refinement; however, that approach only works when the relevant shapes are static. In particular, for split we have the following:

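```python
from typing import List

import torch
from torch import Tensor

aten = torch.ops.aten


def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
    input_sizes = self.shape
    dim_size = input_sizes[dim]
    if split_size == 0:
        assert dim_size == 0
        return [self]
    # At least one chunk, so an empty dim does not index an empty list below.
    chunks = max((dim_size + split_size - 1) // split_size, 1)
    split_sizes = [split_size for _ in range(chunks)]
    # The last chunk absorbs the remainder when split_size does not divide dim_size.
    split_sizes[-1] = split_size - (split_size * chunks - dim_size)
    return aten.split_with_sizes(self, split_sizes, dim)
```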

Here we rely on knowing dim_size (the input's size along dim, which determines both the number of chunks and the size of the last one) statically, and we often don't know it until after shape refinement. This leads to a catch-22: we can't decompose early because shape refinement hasn't run yet, but we have to decompose early because shape refinement doesn't work for lists of tensors.
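A small stand-alone sketch of the same arithmetic makes the dependency concrete. The helper `split_sizes_for` is hypothetical (not part of torch-mlir or PyTorch), and it models "not known statically" as `None`: the length of the output list changes with `dim_size`, so with a dynamic dim there is nothing concrete for an early decomposition to emit.

```python
from typing import List, Optional


def split_sizes_for(dim_size: Optional[int], split_size: int) -> Optional[List[int]]:
    """Per-chunk sizes for splitting a dim of `dim_size` into pieces of
    `split_size`, or None when `dim_size` is not known statically."""
    if dim_size is None:
        # Dynamic dim: neither the number of outputs nor their sizes are known.
        return None
    if split_size == 0:
        assert dim_size == 0
        return [0]
    chunks = max((dim_size + split_size - 1) // split_size, 1)
    sizes = [split_size] * chunks
    # The last chunk absorbs the remainder.
    sizes[-1] = dim_size - split_size * (chunks - 1)
    return sizes


print(split_sizes_for(8, 2))     # [2, 2, 2, 2]
print(split_sizes_for(9, 2))     # [2, 2, 2, 2, 1]
print(split_sizes_for(None, 2))  # None
```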

My understanding of how lists of tensors currently work in the shape refinement pass (e.g., aten.cat takes a list of tensors as an argument) is that we look at the users of a PrimListConstructOp and assume that a new list is constructed for each mutating user. That way we always have a set of SSA values to work with when computing the shape function of something like aten.cat. The issue is that, for the decomposition of split ops, there is no PrimListConstructOp to work with until the decomposition happens.
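As a simplified illustration (a toy Python model, not torch-mlir's actual shape library code): once the list operand of `aten.cat` is traced back to a ListConstruct, the shape function receives one concrete shape per element and can combine them. A split result has no such ListConstruct to trace back to.

```python
from typing import List


def cat_shape(tensor_shapes: List[List[int]], dim: int) -> List[int]:
    """Toy shape function for aten.cat. The per-element shapes are available
    only because the list operand comes from a ListConstruct."""
    assert tensor_shapes, "cat needs at least one tensor"
    result = list(tensor_shapes[0])
    dim = dim if dim >= 0 else dim + len(result)
    for shape in tensor_shapes[1:]:
        assert len(shape) == len(result), "ranks must match"
        for i, (a, b) in enumerate(zip(result, shape)):
            if i != dim:
                assert a == b, "non-concatenated dims must match"
        result[dim] += shape[dim]
    return result


# Roughly: %l = torch.prim.ListConstruct %a, %b ; torch.aten.cat %l, %int0
print(cat_shape([[2, 3], [4, 3]], dim=0))  # [6, 3]
```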

My thinking is that the shape library should still be able to deal with this by looking at the result of the shape calculation directly (although I don't know exactly what the implementation would look like). For split ops, if the sizes involved are static, then after simplifying the shape function we should hopefully know the length of the result list and can infer that a PrimListConstructOp will be filled in there. If this doesn't sound feasible, I'd greatly appreciate any corrections to my understanding!

cc @silvasean @ramiro050 @JakopinA

qedawkins, Jul 28 '22 19:07

There isn't an easy answer here if the number of splits isn't known when we do the early decomposition. It creates a catch-22 as you say. Can you link the shape functions for split?

silvasean, Jul 28 '22 20:07

This upstream PyTorch PR includes the shape functions: https://github.com/pytorch/pytorch/pull/79194/files. The main problem function is this one for aten.split.Tensor:

```python
def split_tensor(self: List[int], split_size: int, dim: int = 0) -> List[List[int]]:
    # `self` is the input's shape; the number of chunks depends on self[dim].
    dim_size = self[dim]
    chunks = (dim_size + split_size - 1) // split_size
    split_sizes = [split_size for i in range(chunks)]
    split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
    # split_fn is a helper that is not shown here.
    return split_fn(self, split_sizes, dim)
```

The problem is that it depends on the input's size at dim. I was mostly wondering whether this is possible if we only support the case where that size is static. The early decompositions in #820 and #1056 both only support a static size, so I was wondering whether we could support the static case at the normal decomposition time and add a failure case to shape refinement for when the size is dynamic.
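A toy, self-contained version of that proposal, loosely following the shape function above: it returns one shape per output when the size at `dim` is static and signals failure (`None`) otherwise. The name `split_tensor_static` and the use of -1 to mark a dynamic size (borrowed from the `annotate_args` convention in the test below) are assumptions for illustration; this is not torch-mlir's shape library.

```python
from typing import List, Optional

UNKNOWN = -1  # assumption: -1 marks a dynamic dim in this toy model


def split_tensor_static(
    self: List[int], split_size: int, dim: int = 0
) -> Optional[List[List[int]]]:
    """Toy split.Tensor shape function that only handles a static size at
    `dim` and reports failure (None) otherwise."""
    dim_size = self[dim]
    if dim_size == UNKNOWN:
        return None  # dynamic: the number of results cannot be determined
    if split_size == 0:
        assert dim_size == 0
        return [list(self)]
    chunks = max((dim_size + split_size - 1) // split_size, 1)
    shapes = []
    for i in range(chunks):
        shape = list(self)
        # Every chunk has split_size elements except possibly the last one.
        shape[dim] = min(split_size, dim_size - i * split_size)
        shapes.append(shape)
    return shapes


print(split_tensor_static([9, 6, 8], 2, 2))   # [[9, 6, 2], [9, 6, 2], [9, 6, 2], [9, 6, 2]]
print(split_tensor_static([9, 6, -1], 2, 2))  # None
```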

qedawkins, Jul 28 '22 20:07

Can you show the IR before and after the early decomposition pass for the cases that work now?

silvasean, Jul 28 '22 22:07

For split.Tensor, #820 finds a static size (at least some of the time) by looking for a defining aten.view op; for example, see the following test case:

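```python
import torch

# e2e test-suite helpers; module paths as used in torch-mlir around mid-2022,
# and they may differ in other versions.
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case


class SplitTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        split_size = 2
        dim = 2
        # The static view gives the split input a known size (8) along dim 2.
        splits = torch.ops.aten.split(x.view(9, 6, 8), split_size, dim)
        return splits[0], splits[1], splits[2], splits[3]


@register_test_case(module_factory=lambda: SplitTensorModule())
def SplitTensorModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(9, 6, 8))
```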
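The lookup that #820 relies on can be modeled in a few lines. This is a toy Python model over a hypothetical `Op` record keyed by SSA name, not the actual C++ pattern in torch-mlir: starting from the split's input, it checks for a defining `aten.view` whose size list is a ListConstruct of integer constants, and gives up otherwise.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Op:
    """Hypothetical, minimal stand-in for an IR operation."""
    name: str
    operands: List[str] = field(default_factory=list)
    value: Optional[int] = None  # payload of a torch.constant.int


def static_dim_size(defs: Dict[str, Op], tensor: str, dim: int) -> Optional[int]:
    """Size of `tensor` at `dim` if it comes from an aten.view whose size list is
    a ListConstruct of non-negative integer constants; None otherwise."""
    view = defs.get(tensor)
    if view is None or view.name != "torch.aten.view":
        return None
    sizes = defs.get(view.operands[1])
    if sizes is None or sizes.name != "torch.prim.ListConstruct":
        return None
    elem = defs.get(sizes.operands[dim])
    if elem is None or elem.name != "torch.constant.int":
        return None
    if elem.value is None or elem.value < 0:
        return None  # unknown, or -1 ("infer this dim") in the view
    return elem.value


# Mirrors x.view(9, 6, 8) from the test: SSA name -> defining op.
defs = {
    "%int9": Op("torch.constant.int", value=9),
    "%int6": Op("torch.constant.int", value=6),
    "%int8": Op("torch.constant.int", value=8),
    "%sizes": Op("torch.prim.ListConstruct", ["%int9", "%int6", "%int8"]),
    "%view": Op("torch.aten.view", ["%x", "%sizes"]),
}
print(static_dim_size(defs, "%view", 2))  # 8
print(static_dim_size(defs, "%x", 2))     # None (no defining view)
```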

The IR before DecomposeComplexOpsEarly for this test is:

module attributes {torch.debug_module_name = "SplitTensorModule"} {
  func.func @forward(%arg0: !torch.vtensor<[?,?,?],f32>) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) {
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %int2 = torch.constant.int 2
    %int9 = torch.constant.int 9
    %int6 = torch.constant.int 6
    %int8 = torch.constant.int 8
    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int3 = torch.constant.int 3
    %2 = torch.prim.ListConstruct %int9, %int6, %int8 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.view %1, %2 : !torch.tensor, !torch.list<int> -> !torch.tensor
    %4 = torch.aten.split.Tensor %3, %int2, %int2 : !torch.tensor, !torch.int, !torch.int -> !torch.list<tensor>
    %5 = torch.aten.__getitem__.t %4, %int0 : !torch.list<tensor>, !torch.int -> !torch.tensor
    %6 = torch.aten.__getitem__.t %4, %int0 : !torch.list<tensor>, !torch.int -> !torch.tensor
    %7 = torch.aten.__getitem__.t %4, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
    %8 = torch.aten.__getitem__.t %4, %int2 : !torch.list<tensor>, !torch.int -> !torch.tensor
    %9 = torch.aten.__getitem__.t %4, %int3 : !torch.list<tensor>, !torch.int -> !torch.tensor
    %10 = torch.prim.TupleConstruct %6, %7, %8, %9 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor, tensor, tensor>
    %int0_0 = torch.constant.int 0
    %11 = torch.prim.TupleIndex %10, %int0_0 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
    %int1_1 = torch.constant.int 1
    %12 = torch.prim.TupleIndex %10, %int1_1 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
    %int2_2 = torch.constant.int 2
    %13 = torch.prim.TupleIndex %10, %int2_2 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
    %int3_3 = torch.constant.int 3
    %14 = torch.prim.TupleIndex %10, %int3_3 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
    return %11, %12, %13, %14 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor
  }
}

and the IR after it:

// -----// IR Dump After DecomposeComplexOpsEarly (torch-decompose-complex-ops-early) //----- //
func.func @forward(%arg0: !torch.vtensor<[?,?,?],f32>) -> (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) {
  %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
  %1 = torch.copy.to_tensor %0 : !torch.tensor
  %int2 = torch.constant.int 2
  %int9 = torch.constant.int 9
  %int6 = torch.constant.int 6
  %int8 = torch.constant.int 8
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int3 = torch.constant.int 3
  %2 = torch.prim.ListConstruct %int9, %int6, %int8 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
  %3 = torch.aten.view %1, %2 : !torch.tensor, !torch.list<int> -> !torch.tensor
  %int1_0 = torch.constant.int 1
  %int0_1 = torch.constant.int 0
  %int2_2 = torch.constant.int 2
  %4 = torch.aten.slice.Tensor %3, %int2, %int0_1, %int2_2, %int1_0 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  %int2_3 = torch.constant.int 2
  %int4 = torch.constant.int 4
  %5 = torch.aten.slice.Tensor %3, %int2, %int2_3, %int4, %int1_0 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  %int4_4 = torch.constant.int 4
  %int6_5 = torch.constant.int 6
  %6 = torch.aten.slice.Tensor %3, %int2, %int4_4, %int6_5, %int1_0 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  %int6_6 = torch.constant.int 6
  %int8_7 = torch.constant.int 8
  %7 = torch.aten.slice.Tensor %3, %int2, %int6_6, %int8_7, %int1_0 : !torch.tensor, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.tensor
  %8 = torch.prim.ListConstruct %4, %5, %6, %7 : (!torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor) -> !torch.list<tensor>
  %9 = torch.aten.__getitem__.t %8, %int0 : !torch.list<tensor>, !torch.int -> !torch.tensor
  %10 = torch.aten.__getitem__.t %8, %int0 : !torch.list<tensor>, !torch.int -> !torch.tensor
  %11 = torch.aten.__getitem__.t %8, %int1 : !torch.list<tensor>, !torch.int -> !torch.tensor
  %12 = torch.aten.__getitem__.t %8, %int2 : !torch.list<tensor>, !torch.int -> !torch.tensor
  %13 = torch.aten.__getitem__.t %8, %int3 : !torch.list<tensor>, !torch.int -> !torch.tensor
  %14 = torch.prim.TupleConstruct %10, %11, %12, %13 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor -> !torch.tuple<tensor, tensor, tensor, tensor>
  %int0_8 = torch.constant.int 0
  %15 = torch.prim.TupleIndex %14, %int0_8 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
  %int1_9 = torch.constant.int 1
  %16 = torch.prim.TupleIndex %14, %int1_9 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
  %int2_10 = torch.constant.int 2
  %17 = torch.prim.TupleIndex %14, %int2_10 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
  %int3_11 = torch.constant.int 3
  %18 = torch.prim.TupleIndex %14, %int3_11 : !torch.tuple<tensor, tensor, tensor, tensor>, !torch.int -> !torch.tensor
  return %15, %16, %17, %18 : !torch.tensor, !torch.tensor, !torch.tensor, !torch.tensor
}
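For reference, the start/end constants on the four `aten.slice.Tensor` ops in the dump above follow from the split sizes [2, 2, 2, 2] by a running sum. A minimal Python sketch of that mapping (the helper name `slice_bounds` is hypothetical):

```python
from itertools import accumulate
from typing import List, Tuple


def slice_bounds(split_sizes: List[int]) -> List[Tuple[int, int]]:
    """Turn per-chunk sizes into (start, end) pairs, one per
    aten.slice.Tensor (step 1) in the decomposed IR."""
    ends = list(accumulate(split_sizes))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


print(slice_bounds([2, 2, 2, 2]))  # [(0, 2), (2, 4), (4, 6), (6, 8)]
```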

Let me know if you need a different example.

qedawkins, Jul 29 '22 02:07

I think something that would work a bit better here is to look at the __getitem__ calls after the split, see which indices are accessed, and assume that those are all the indices (i.e., infer the length of the list from them). Of course, we can't know that for sure, so we would insert a runtime assertion to guard it; that assertion will hold in practice for cases like this and most others. That's still a little hacky, but pretty good imo.
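The heuristic can be sketched as a toy Python model. `infer_list_length` and `check_length` are hypothetical names, and a real pass would emit a runtime assertion op into the IR rather than call a Python function. With the indices used in the test above (0, 0, 1, 2, 3), it infers a length of 4:

```python
from typing import List, Optional


def infer_list_length(getitem_indices: List[Optional[int]]) -> Optional[int]:
    """Guess a list's length from the constant indices of its __getitem__ users.
    Returns None if there are no users, or any index is non-constant (None)
    or negative (counts from the end, so it says nothing about the length)."""
    if not getitem_indices or any(i is None or i < 0 for i in getitem_indices):
        return None
    return max(getitem_indices) + 1


def check_length(actual: int, assumed: int) -> None:
    # Runtime guard inserted alongside the rewrite: the inferred length is
    # only an assumption, so fail loudly if it does not hold.
    if actual != assumed:
        raise RuntimeError(f"split produced {actual} tensors, expected {assumed}")


n = infer_list_length([0, 0, 1, 2, 3])
print(n)  # 4
sizes = [2, 2, 2, 2]  # split of an 8-wide dim with split_size 2
check_length(len(sizes), n)  # passes
```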

The real solution here is to run shape inference, maximizing value semantics, decompositions, dtype inference, inlining global slots, and canonicalization in a fixed-point iteration, but that's a much larger change.
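The overall shape of such a fixed-point driver can be sketched in toy form: plain Python over a dict standing in for the IR. The pass names are hypothetical and this is not torch-mlir's pass-manager API. Even with decomposition ordered before shape refinement, repeating the sweep until nothing changes lets the split be decomposed once its dim size becomes known:

```python
from typing import Any, Callable, Dict, List, Tuple

IR = Dict[str, Any]
Pass = Callable[[IR], Tuple[IR, bool]]  # returns (new IR, whether it changed)


def run_to_fixed_point(ir: IR, passes: List[Pass], max_iters: int = 100) -> IR:
    """Sweep all passes repeatedly until one full sweep changes nothing."""
    for _ in range(max_iters):
        changed = False
        for run_pass in passes:
            ir, did_change = run_pass(ir)
            changed = changed or did_change
        if not changed:
            return ir
    raise RuntimeError("passes did not reach a fixed point")


def decompose_split(ir: IR) -> Tuple[IR, bool]:
    # Can only fire once the size along the split dim is known.
    if ir["num_results"] is None and ir["dim_size"] is not None:
        num = -(-ir["dim_size"] // ir["split_size"])  # ceiling division
        return {**ir, "num_results": num}, True
    return ir, False


def refine_shapes(ir: IR) -> Tuple[IR, bool]:
    # Stand-in for shape refinement: propagate the static view size to the dim.
    if ir["dim_size"] is None and ir["view_sizes"] is not None:
        return {**ir, "dim_size": ir["view_sizes"][ir["dim"]]}, True
    return ir, False


ir = {"view_sizes": [9, 6, 8], "dim": 2, "split_size": 2,
      "dim_size": None, "num_results": None}
result = run_to_fixed_point(ir, [decompose_split, refine_shapes])
print(result["num_results"])  # 4
```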

silvasean, Jul 29 '22 17:07

Since upstream decompositions for some of these split ops were added recently, I don't think there is any urgency to support list-of-tensor types. This was mainly a question about what to do with #1056, because the solution presented there did not seem viable. I would guess that your latter suggestion is preferred, since it feels like the 'correct' way to handle decompositions anyway (that's my thinking, at least).

qedawkins, Jul 29 '22 18:07

https://github.com/llvm/torch-mlir/pull/1165 should help with this.

silvasean, Aug 05 '22 23:08

Thanks! After that lands this issue can probably be closed.

qedawkins, Aug 05 '22 23:08

Closing, since https://github.com/llvm/torch-mlir/pull/1165 completely fixes these catch-22s.

silvasean, Oct 07 '22 13:10