torch-mlir

Some iteration-related operations fail with a confusing error message

Open: dellis23 opened this issue 3 years ago.

The following error occurs when I try to do a number of operations involving iteration:

torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: unsupported by backend contract: tensor with unknown rank
note: see current operation: %6 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<*,f32>
note: this is likely due to a missing shape transfer function in shape_lib_gen.py

Here's a script that reproduces various cases that cause the failure, and some that do not:

import functorch
import torch
import torch_mlir

from functorch._src.compile_utils import strip_overloads


# Fails
def tuple_stuff(w, grad_w):
    for i, grad in enumerate(grad_w):
        w[i] = w[i] + grad
    return w


# Fails
def tuple_stuff_2(w):
    for i, item in enumerate(w):
        return w[i]


# Fails
def tuple_stuff_3(w):
    for i in w:
        return i
        

# Fails
def tuple_stuff_4(w):
    for i in w:
        return torch.tensor(3.0) * i


# Passes
def tuple_stuff_5(w):
    size = len(w)
    for i in range(size):
        return w[i] * 2.


# Fails
def tuple_stuff_6(w):
    size = len(w)
    for i in range(size):
        w[i] = w[i] * 2.
    return w


def main():
    w = torch.tensor([1., 2., 3.])
    grad_w = torch.tensor([1., 1., 1.])
    graph = functorch.make_fx(tuple_stuff_6)(w)
    strip_overloads(graph)
    linalg_on_tensors_mlir = torch_mlir.compile(
        graph,
        (w,),
        output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
        use_tracing=False)


if __name__ == "__main__":
    main()

It's unclear to me exactly what conditions cause things to break.

I talked to Ramiro about this, and it looks like the following operation is the culprit:

    %24 = torch.operator "aten.unbind.int"(%23, %int0) : (!torch.tensor<*,f32>, !torch.int) -> !torch.list<tensor>
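For context, here is a minimal eager-mode sketch of what aten.unbind.int computes and why iteration can produce it. As far as I know, iterating over a tensor in PyTorch goes through `Tensor.unbind(0)`, so `for i in w` yields the slices along dim 0, and taking one element of that list is the same as indexing the tensor directly. This only checks eager PyTorch semantics; it does not exercise torch-mlir.

```python
import torch

w = torch.tensor([1.0, 2.0, 3.0])

# torch.unbind splits a tensor along `dim` into a tuple of slices,
# each with that dimension removed (here: three 0-d tensors).
pieces = torch.unbind(w, 0)
print(len(pieces), pieces[0].shape)  # 3 torch.Size([])

# Iterating over the tensor yields the same slices, in the same order.
for item, piece in zip(w, pieces):
    assert torch.equal(item, piece)

# unbind(w, 0)[2] is just w[2], i.e. torch.select(w, 0, 2).
assert torch.equal(torch.unbind(w, 0)[2], w[2])
assert torch.equal(w[2], torch.select(w, 0, 2))
```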

A few questions:

  • What exactly is the cause or causes of the failures?
  • Are they something we can fix or support?
  • If not, can we throw a better error message to communicate to the user what the problem is and how to fix it?

dellis23, Oct 13 '22 15:10

Spoke with Ramiro. We think we might be able to fold one of these cases down to a simple index to support this behavior.

dellis23, Oct 13 '22 21:10

  • tuple_stuff - expected to fail because it mutates its input in place (we are actually lucky that this fails at compile time; in principle this program has undefined behavior). We should definitely diagnose this better, though I can't think of a way to do it reliably.
  • tuple_stuff_2 - succeeds on my machine.
  • tuple_stuff_3 - aten.unbind.int needs to be implemented.
  • tuple_stuff_4 - same. The root cause is the torch.operator "aten.unbind.int", which blocks conversion to value semantics and shape inference; as a result, an earlier op is not converted either. We should definitely diagnose this better: satisfiesBackendContract could be tweaked to change the op visitation order so that it visits torch.operator ops first (tweak the walk here).
  • tuple_stuff_5 - works on my machine.
  • tuple_stuff_6 - expected to fail for the same reason as tuple_stuff.
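The visitation-order idea can be illustrated with a toy model. This is not torch-mlir's satisfiesBackendContract (which is C++ walking MLIR operations); `Op`, `violation`, and `first_violation` are hypothetical names, and the two ops are modeled loosely on the IR from the original report. The point is only that checking torch.operator ops first reports the root cause rather than the downstream "unknown rank" symptom.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Op:
    name: str           # op name, e.g. "torch.tensor_static_info_cast"
    result_type: str    # printed result type, e.g. "!torch.vtensor<*,f32>"
    operator: str = ""  # for torch.operator only: the wrapped op name


def violation(op: Op) -> Optional[str]:
    """Return why `op` breaks the (toy) backend contract, or None if it is fine."""
    if op.name == "torch.operator":
        return f"unsupported op {op.operator} (still a torch.operator)"
    if "<*," in op.result_type:
        return "tensor with unknown rank"
    return None


def first_violation(ops, operators_first=False) -> Optional[Tuple[str, str]]:
    """Report the first violation found, walking ops in the chosen order."""
    if operators_first:
        # Stable sort: torch.operator ops move to the front, the rest keep program order.
        ops = sorted(ops, key=lambda op: op.name != "torch.operator")
    for op in ops:
        msg = violation(op)
        if msg is not None:
            return op.name, msg
    return None


# Program order, modeled loosely on the IR in the original report.
ops = [
    Op("torch.tensor_static_info_cast", "!torch.vtensor<*,f32>"),
    Op("torch.operator", "!torch.list<tensor>", operator="aten.unbind.int"),
]

print(first_violation(ops))
# ('torch.tensor_static_info_cast', 'tensor with unknown rank')
print(first_violation(ops, operators_first=True))
# ('torch.operator', 'unsupported op aten.unbind.int (still a torch.operator)')
```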

silvasean, Oct 18 '22 14:10

> tuple_stuff_2 - succeeds on my machine

Whoops, messed up the comment. Passes for me too. Fixed.

> aten.unbind.int needs to be implemented.

It sounds like this is the action item for this issue then? Any thoughts on Ramiro's alternative approach of rewriting the ops to just index directly?

dellis23, Oct 18 '22 15:10

> It sounds like this is the action item for this issue then? Any thoughts on Ramiro's alternative approach of rewriting the ops to just index directly?

Adding the canonicalization pattern is one form of implementing the op, albeit a limited one. I don't think it's necessary to add something like a decomposition/lowering for aten.unbind.int; that would probably require a lot of changes, because the op returns a list of tensors.

@silvasean, for reference, the op aten.unbind.int splits a tensor along a specified dimension into a list of tensors. In the IR where the aten.unbind.int ops are being generated, they are always followed by a list indexing. In other words, input[2] is being turned into torch.unbind(input, 0)[2]. So I suggested to Dan that this can be fixed with a canonicalization pattern on aten.__getitem__.t that, when its list operand comes from an aten.unbind.int, replaces it with an aten.index.Tensor op.
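The shape of that rewrite can be sketched on a toy SSA op list. This is a hypothetical Python model, not torch-mlir's C++ pattern API: `Op` and `fold_unbind_getitem` are illustrative names, the operand plumbing is simplified (the real aten.index.Tensor takes a list of optional index tensors, not an int), and only a constant dim of 0 is handled, since that is where list indexing and tensor indexing coincide.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    result: str                                   # SSA name of the result, e.g. "%2"
    name: str                                     # op name, e.g. "aten.unbind.int"
    operands: list = field(default_factory=list)  # SSA names this op uses
    value: object = None                          # payload for constants


def fold_unbind_getitem(ops):
    """Rewrite getitem(unbind(x, 0), i) into index(x, i) in a toy op list.

    The unbind op itself is left in place; a dead-code pass would remove
    it once nothing uses its result.
    """
    defs = {op.result: op for op in ops}
    out = []
    for op in ops:
        if op.name == "aten.__getitem__.t":
            lst, idx = op.operands
            producer = defs.get(lst)
            if producer is not None and producer.name == "aten.unbind.int":
                x, dim = producer.operands
                dim_op = defs.get(dim)
                if dim_op is not None and dim_op.name == "torch.constant.int" and dim_op.value == 0:
                    # Same result name, so later users of %2 need no change.
                    out.append(Op(op.result, "aten.index.Tensor", [x, idx]))
                    continue
        out.append(op)
    return out


ops = [
    Op("%int0", "torch.constant.int", value=0),
    Op("%int2", "torch.constant.int", value=2),
    Op("%1", "aten.unbind.int", ["%0", "%int0"]),      # %0 is the input tensor
    Op("%2", "aten.__getitem__.t", ["%1", "%int2"]),
]
new = fold_unbind_getitem(ops)
print(new[-1].result, new[-1].name, new[-1].operands)
# %2 aten.index.Tensor ['%0', '%int2']
```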

ramiro050, Oct 18 '22 16:10

Edit: Disregard everything below. I forgot to rebuild when testing whether the new change was making the getitem disappear or whether it was never present at all.

Alright, this is getting a bit hairy: I think there are a few cases floating around, and I'm not sure which one generates what. I attempted to rewrite the unbind / getitem combo here. This compiles, and I even see it performing a rewrite. (Though Ramiro pointed out that I can't just use getitem on the tensor anyway; I need to replace it with AtenIndexTensorOp.)

The problem is that when I run one of the repros that Sean says we should be able to handle (e.g. tuple_stuff_3), I see the following error and IR:

<unknown>:0: error: unsupported by backend contract: tensor with unknown rank
<unknown>:0: note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<*,f32>
<unknown>:0: note: this is likely due to a missing shape transfer function in shape_lib_gen.py
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) //----- //
module attributes {torch.debug_module_name = "tuple_stuff_3"} {
  func.func @forward(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<*,f32> {
    %int0 = torch.constant.int 0
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3],f32> to !torch.vtensor<*,f32>
    %1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32>
    %2 = torch.aten.unbind.int %1, %int0 : !torch.tensor<*,f32>, !torch.int -> !torch.list<tensor>
    %3 = torch.copy.to_vtensor %1 : !torch.vtensor<*,f32>
    return %3 : !torch.vtensor<*,f32>
  }
}

So in this case, there's no getitem. We're just unbinding, and I don't see the unbind result being used at all. But I do see the unranked result of the static info cast being used, which is presumably what causes the failure.
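To double-check that observation, a small text scan can count how often each SSA value in the dump is used. This is a rough sketch with a hypothetical helper, `count_uses`: it treats the printed IR as plain text and matches names with a regex, which is fine for a short dump like this one but is no substitute for MLIR's own use lists.

```python
import re

# Function body copied from the IR dump above (indentation dropped).
IR = """\
%int0 = torch.constant.int 0
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[3],f32> to !torch.vtensor<*,f32>
%1 = torch.copy.to_tensor %0 : !torch.tensor<*,f32>
%2 = torch.aten.unbind.int %1, %int0 : !torch.tensor<*,f32>, !torch.int -> !torch.list<tensor>
%3 = torch.copy.to_vtensor %1 : !torch.vtensor<*,f32>
return %3 : !torch.vtensor<*,f32>
"""


def count_uses(ir: str) -> dict:
    """Map each SSA value defined in `ir` to how many times it is used (text-based)."""
    defined = re.findall(r"^\s*(%\w+) =", ir, flags=re.M)
    # Every occurrence except the defining one counts as a use;
    # the trailing \b keeps "%1" from matching a prefix of e.g. "%10".
    return {name: len(re.findall(re.escape(name) + r"\b", ir)) - 1 for name in defined}


uses = count_uses(IR)
print(uses)  # {'%int0': 1, '%0': 1, '%1': 2, '%2': 0, '%3': 1}
print([name for name, n in uses.items() if n == 0])  # ['%2']
```

Here %2 (the unbind result) has no uses, while %0 (the cast to the unranked `!torch.vtensor<*,f32>`) feeds the rest of the function.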

I'll have to try to figure out how to reproduce the original case of the unbind followed by the getitem. I'm not sure what under the hood is generating the one that we are seeing right now, since it doesn't seem to be in my forward function. But the bigger question now is what approach to take to solve this one.

dellis23, Oct 18 '22 21:10

> So I suggested to Dan that this can be fixed with a canonicalization pattern on aten.__getitem__.t that when preceded by an aten.unbind.int replaces it with an aten.index.Tensor op.

That makes sense to me.

silvasean, Oct 19 '22 09:10