
More TorchToLinalg view failure cases

Open. dan-garvey opened this issue 1 year ago (11 comments).

    %114 = torch.aten.view %112, %113 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>
<stdin>:929:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
    %7 = torch.aten.view %5, %6 : !torch.vtensor<[1,?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>  

For the latter, I know @ramiro050 landed https://github.com/huggingface/transformers/pull/26059 to avoid this issue for many models (the views are often no-ops).

While optimizing the model developers' code is great, endlessly patching Hugging Face models seems worse than supporting the lowering and then having a pass remove the view as a no-op.

dan-garvey, Feb 07 '24 02:02

<stdin>:929:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
    %7 = torch.aten.view %5, %6 : !torch.vtensor<[1,?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>

Should this be handled by the torch.aten.view folder (it's an identity)?
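For the [1,?] -> [1,?] case, the types alone are enough: with a single dynamic dimension in the same position on both sides, equal element counts force the two unknown sizes to match. A minimal Python sketch of such a type-only check (the function name and the use of None for ? are hypothetical, and this is not torch-mlir's actual folder) also shows why the [?,?] -> [1,1,?,?] case cannot be decided this way:

```python
from typing import Optional, Sequence

# None stands for a dynamic ("?") dimension, e.g. [1, None] for [1,?].
Shape = Sequence[Optional[int]]


def is_provably_identity_view(src: Shape, dst: Shape) -> bool:
    """True only if the operand and result types alone prove the view is a no-op."""
    if len(src) != len(dst):
        return False
    for s, d in zip(src, dst):
        # Dynamic dims must sit at the same positions in both shapes...
        if (s is None) != (d is None):
            return False
        # ...and static dims must match exactly.
        if s is not None and s != d:
            return False
    num_dynamic = sum(s is None for s in src)
    if num_dynamic == 0:
        return True
    if num_dynamic == 1:
        # numel(src) == numel(dst) then forces the unknown sizes to be equal,
        # unless a static dim is 0 (both sides are empty whatever the sizes).
        return all(s != 0 for s in src if s is not None)
    # Two or more unknowns: [?,?] -> [?,?] could be a [2,3] -> [3,2] reshape.
    return False
```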

  %114 = torch.aten.view %112, %113 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>

This possibly started life as an unsqueeze. If it did, we could lower it to the torch-mlir unsqueeze op, which lowers directly to tensor.expand_shape. But whether we can deduce this depends on where the op came from. This PR, https://github.com/llvm/llvm-project/pull/69267, might be helpful too (if it was a reshape from (a, b) -> (1, 1, a/4, b*4), say).
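If the sizes can be compared symbolically, recognizing an unsqueeze-like view amounts to working out where the extra unit dims go. A minimal sketch, assuming sizes are ints or strings naming a symbolic size (a hypothetical encoding), that groups the result dims per operand dim in the style of an expand_shape reassociation:

```python
from typing import List, Optional, Sequence, Union

Size = Union[int, str]  # int: static size; str: name of a symbolic size


def unit_dim_reassociation(
    src: Sequence[Size], dst: Sequence[Size]
) -> Optional[List[List[int]]]:
    """If dst is src with extra size-1 dims inserted, return one group of dst
    indices per src dim; otherwise return None."""
    groups: List[List[int]] = []
    pending: List[int] = []  # inserted 1s waiting for the next src dim
    i = 0
    for j, d in enumerate(dst):
        if i < len(src) and d == src[i]:
            groups.append(pending + [j])
            pending = []
            i += 1
        elif d == 1:
            pending.append(j)
        else:
            return None  # a size that is neither the next src dim nor a 1
    if i != len(src):
        return None  # some src dim never appeared in dst
    if pending:
        if not groups:
            return None  # rank-0 source: not handled in this sketch
        groups[-1].extend(pending)  # trailing 1s join the last group
    return groups
```

This only works because "k" == "k" is decided symbolically; with two plain ? dims there is nothing to compare.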

newling, Feb 07 '24 17:02

It started its life as view(1, 1, key_size, key_size); I actually circumvented the error by rewriting it with unsqueezes. I suppose the problem with two dynamic dims like that is that you don't know whether they are the same unless you're working with strict symbolic shapes?
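A minimal sketch of that kind of rewrite, assuming the mask is built as in the Qwen causal mask quoted later in this thread (the exact unsqueeze version used isn't shown): the unsqueezes only add leading unit dims, so no dynamic size has to be matched against the operand's sizes.

```python
import torch


def causal_mask_view(key_size: int) -> torch.Tensor:
    # Original form: a view with two dynamic sizes that TorchToLinalg rejects.
    mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool))
    return mask.view(1, 1, key_size, key_size)


def causal_mask_unsqueeze(key_size: int) -> torch.Tensor:
    # Workaround: insert the two leading unit dims with unsqueeze instead.
    mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool))
    return mask.unsqueeze(0).unsqueeze(0)
```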

dan-garvey, Feb 07 '24 18:02

it started its life as view(1, 1, key_size, key_size),

Do you have a link to the PyTorch model definition where they do this view?

I suppose the problem with two dynamic dims like that is that you don't know whether they are the same unless you're working with strict symbolic shapes?

Yeah, this is a very annoying limitation. Since we don't have symbolic shapes, we cannot know if two dimensions of a tensor are the same without symbolically analyzing the op that produced the tensor (and potentially the ops before it). The patch referenced (https://github.com/llvm/llvm-project/pull/69267) would likely fix this, as would using the FX importer path into Torch-MLIR.

ramiro050, Feb 08 '24 00:02

I am actually coming via fx_importer, but I don't think the view lowering takes strict_symbolic_shapes into account?

here's a link:

https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L505

The other offending views are the same as the ones here: https://github.com/huggingface/transformers/pull/26059

dan-garvey, Feb 08 '24 00:02

I am actually coming via fx_importer, but I don't think the view lowering takes strict_symbolic_shapes into account?

Oh, then maybe we can fix this by modifying the importer. The importer should have access to the symbolic shapes that PyTorch generates. Once we import, all the symbolic shape information is lost, and retrieving it can become very painful.

here's a link:

https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L505

the other offending views are the same as from here: huggingface/transformers#26059

Yeah, this is very annoying to fix once we are inside Torch-MLIR. One can definitely match on the IR, but it is hard to make a pattern general enough to be useful for slightly different situations. For example, in the issue I was facing, I could've solved it as follows:

  1. Match on the view op
  2. Look at the defining op of the view's operand
  3. Make sure the defining op is a torch.rand
  4. Use the arguments of torch.rand to get the symbolic shape of the view operand
  5. Compare the symbolic shape with the shape-list argument of view
  6. Handle the case where the conclusion is that the view is just an unsqueeze (or an identity)
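The six steps can be sketched on a toy IR. This is a Python sketch, not torch-mlir's C++ pattern API: the Op class is hypothetical, and sizes are ints or strings naming the SSA value that holds a dynamic size (so two uses of %0 compare equal). As written, it only handles views that add leading unit dims:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Union

Size = Union[int, str]  # int: static size; str: SSA name of a dynamic size, e.g. "%0"


@dataclass
class Op:  # hypothetical stand-in for an MLIR operation
    name: str
    operands: List["Op"] = field(default_factory=list)
    sizes: List[Size] = field(default_factory=list)  # the op's size-list argument


def classify_view_from_rand(view: Op) -> Optional[str]:
    # Step 1: match on the view op.
    if view.name != "aten.view":
        return None
    # Steps 2-3: the view's operand must be defined directly by aten.rand.
    producer = view.operands[0]
    if producer.name != "aten.rand":
        return None
    # Step 4: rand's size list is the symbolic shape of the view operand.
    src, dst = producer.sizes, view.sizes
    # Steps 5-6: compare it with the view's size list.
    if dst == src:
        return "identity"
    lead = len(dst) - len(src)
    if lead > 0 and dst[lead:] == src and all(d == 1 for d in dst[:lead]):
        return "unsqueeze"
    return None
```

A tril of ones, as in this issue, is rejected by this matcher even though the mask has the same shape.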

However, this pattern is both a bit complex and so fragile that it would not fix your issue. In your case, you have to do a similar thing but use torch.ones instead to get the symbolic shape, and you also need to know that any ops in between leave the shape of the tensor unchanged. Here, that means knowing that torch.tril only modifies the contents. It becomes a game of whack-a-mole.

I think we should explore making the importer a bit smarter first, and if that does not work, we can do an explicit pattern rewrite.

ramiro050, Feb 08 '24 01:02

To reproduce this error:

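import torch
from torch_mlir import torchscript

class OnesTrilViewModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # key_size is dynamic because axis 0 of the input is marked dynamic below.
        key_size = x.shape[0]
        # Same pattern as the causal mask: tril(ones(k, k)) viewed as [1, 1, k, k].
        return torch.tril(torch.ones((key_size, key_size), dtype=torch.bool)).view(
            1, 1, key_size, key_size
        )

example_input = torch.empty(5)

# dynamic_axes=[0] makes the imported input type !torch.vtensor<[?],f32>.
placeholder = torchscript.TensorPlaceholder.like(example_input, dynamic_axes=[0])
out = torchscript.compile(OnesTrilViewModule(), placeholder)

print(out)

# Save the IR so it can be passed to iree-compile (or another Linalg-based pipeline).
with open("torch.mlir", "w") as f:
    f.write(str(out))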

which produces this IR:

module attributes {torch.debug_module_name = "OnesTrilViewModule"} {
  func.func @forward(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,1,?,?],i1> {
    %int1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int11 = torch.constant.int 11
    %none = torch.constant.none
    %0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
    %1 = torch.prim.ListConstruct %0, %0 : (!torch.int, !torch.int) -> !torch.list<int>
    %2 = torch.aten.ones %1, %int11, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],i1>
    %3 = torch.aten.tril %2, %int0 : !torch.vtensor<[?,?],i1>, !torch.int -> !torch.vtensor<[?,?],i1>
    %4 = torch.prim.ListConstruct %int1, %int1, %0, %0 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %5 = torch.aten.view %3, %4 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>
    return %5 : !torch.vtensor<[1,1,?,?],i1>
  }
}

Compiling through Linalg (I used iree-compile torch.mlir --iree-hal-target-backends=llvm-cpu --compile-to=flow, but any path through Linalg would do), you get:

torch.mlir:12:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
    %5 = torch.aten.view %3, %4 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>
         ^
torch.mlir:12:10: note: see current operation: %36 = "torch.aten.view"(%34, %35) : (!torch.vtensor<[?,?],i1>, !torch.list<int>) -> !torch.vtensor<[1,1,?,?],i1>

newling, Feb 08 '24 13:02

I'm imagining a utility function which attempts to get a symbolic shape for any value.

FailureOr<ListConstruct> getSymbolicShape(Value v);

The implementation recursively walks through producers until it has something to return, or fails. Ops like Tril would contribute logic that says the result has the same shape as the operand. Ops like Ones and View would return their result shapes directly.

In a View canonicalizer, you would then call getSymbolicShape on the operand and compare that shape to the result shape to determine whether the view can be converted into an Unsqueeze.
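A minimal Python sketch of that proposal, using a hypothetical toy Op class in place of MLIR values (get_symbolic_shape, canonicalize_view and the op-name sets are illustrative, not existing torch-mlir APIs). Returning None plays the role of the FailureOr failure:

```python
from dataclasses import dataclass, field
from typing import List, Optional, Union

Size = Union[int, str]  # int: static size; str: SSA name of a dynamic size


@dataclass
class Op:  # hypothetical stand-in for an MLIR operation and its result
    name: str
    operands: List["Op"] = field(default_factory=list)
    sizes: List[Size] = field(default_factory=list)


# Ops whose result shape is their size list (assumes views never use -1).
SHAPE_FROM_SIZES = {"aten.ones", "aten.zeros", "aten.rand", "aten.view"}
# Ops whose result has the same shape as their first operand.
SHAPE_PRESERVING = {"aten.tril", "aten.triu"}


def get_symbolic_shape(v: Op) -> Optional[List[Size]]:
    if v.name in SHAPE_FROM_SIZES:
        return list(v.sizes)
    if v.name in SHAPE_PRESERVING:
        return get_symbolic_shape(v.operands[0])  # walk up the producer chain
    return None  # unknown producer: give up


def canonicalize_view(view: Op) -> Optional[str]:
    src = get_symbolic_shape(view.operands[0])
    if src is None:
        return None
    dst = view.sizes
    if dst == src:
        return "identity"
    lead = len(dst) - len(src)
    if lead > 0 and dst[lead:] == src and all(d == 1 for d in dst[:lead]):
        return "unsqueeze"
    return None
```

Supporting another op means adding it to one of the two sets, which is the per-op logic described above; the cost is that every shape-preserving op has to be listed explicitly.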

@ramiro050 I'm curious what "making the importer a bit smarter" might look like (I know very little about the importer!)

newling, Feb 08 '24 13:02

@dan-garvey, are you using a similar import process as the reproducer? I think you mentioned you were using the fx importer rather than the torchscript one.

I'm curious what "making the importer a bit smarter" might look like (I know very little about the importer!)

When PyTorch generates an FX graph using Dynamo, my understanding is that all the tensor shapes in the graph are symbolic. This means that when we are about to import the view op into torch-mlir, we can actually check the shape argument and see which dimension sizes match the dimension sizes of the input tensor. Then we can replace those with torch.size ops that get the sizes from the input tensor. For example, in the case Dan is running into above,

causal_mask = torch.tril(
  torch.ones((key_size, key_size), dtype=torch.bool, device=query.device))
causal_mask = causal_mask.view(1, 1, key_size, key_size)

would become

causal_mask = torch.tril(
  torch.ones((key_size, key_size), dtype=torch.bool, device=query.device))
-causal_mask = causal_mask.view(1, 1, key_size, key_size)
+causal_mask = causal_mask.view(1, 1, causal_mask.size(0), causal_mask.size(1))

This type of transformation can likely be written as an FX transformation in Python before we ingest the graph into Torch-MLIR. I think we should take advantage of the fact that the symbolic shape information is there rather than try to recreate it ourselves.
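The matching step of such a transformation can be sketched without torch. Here symbolic sizes are plain strings standing in for the SymInts Dynamo records (an assumption; a real pass would read them from the FX node metadata), and ("size", i) is a hypothetical encoding of causal_mask.size(i). Each symbol is matched to an input dim not already used, which reproduces the size(0), size(1) form in the diff:

```python
from typing import List, Optional, Sequence, Tuple, Union

Size = Union[int, str]     # int: static size; str: symbolic size such as "s0"
SizeRef = Tuple[str, int]  # ("size", i): size i of the view's input tensor


def rewrite_view_sizes(
    input_shape: Sequence[Size], view_sizes: Sequence[Size]
) -> Optional[List[Union[int, SizeRef]]]:
    """Replace symbolic view sizes with references to input dims of the same
    symbol; return None if some symbol has no matching input dim."""
    used = set()
    out: List[Union[int, SizeRef]] = []
    for s in view_sizes:
        if isinstance(s, int):
            out.append(s)  # static sizes need no rewriting
            continue
        matches = [i for i, d in enumerate(input_shape) if d == s]
        if not matches:
            return None  # e.g. a product like s0*s1: leave the view alone
        fresh = [i for i in matches if i not in used]
        idx = (fresh or matches)[0]  # prefer an input dim not referenced yet
        used.add(idx)
        out.append(("size", idx))
    return out
```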

ramiro050, Feb 08 '24 16:02

Also, the importer I'm talking about is: https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py

Here's an example for how to use it:

https://github.com/llvm/torch-mlir/blob/main/test/python/fx_importer/basic_test.py

ramiro050, Feb 08 '24 16:02

I haven't played around with fx transformations, but I completely agree that fundamentally we should take advantage of having access to the symbolic dims.

Given what you propose, I wonder if it wouldn't be better to represent all the dynamic dims as the output of a size op. I also wonder how hard it would be to fuse all the redundant size ops this would create.
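Fusing duplicate size ops is essentially common-subexpression elimination keyed on (tensor, dim), assuming size queries have no side effects. A minimal Python sketch over a hypothetical list of (result, tensor, dim) triples:

```python
from typing import Dict, List, Tuple

SizeOp = Tuple[str, str, int]  # (result name, tensor name, dim)


def cse_size_ops(ops: List[SizeOp]) -> Tuple[List[SizeOp], Dict[str, str]]:
    """Keep the first size op per (tensor, dim); map later duplicates to it."""
    first: Dict[Tuple[str, int], str] = {}
    kept: List[SizeOp] = []
    replacements: Dict[str, str] = {}
    for result, tensor, dim in ops:
        key = (tensor, dim)
        if key in first:
            replacements[result] = first[key]  # uses of result -> surviving op
        else:
            first[key] = result
            kept.append((result, tensor, dim))
    return kept, replacements
```

The generic MLIR cse pass performs this kind of merge for ops it can prove side-effect free.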

dan-garvey, Feb 12 '24 05:02

given what you propose I wonder if it wouldn't be better to represent all the dynamic dims as the output of a size op? I wonder how hard it would be to fuse all the redundant size ops this would create

The challenge is representing all dynamic dims as the output of a size op applied to the input tensor of the view op. Our conversion pattern for aten.view only knows how to deal with such size ops, and achieving this for the failure case here requires symbolic shape information.
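That constraint can be written as a simple check. A minimal sketch, assuming a hypothetical encoding where each size-list entry is either an int or ("size", tensor, dim) for a size op: the view lowering, as described, accepts the list only when every dynamic entry reads a dim of the view's own input.

```python
from typing import Sequence, Tuple, Union

SizeEntry = Union[int, Tuple[str, str, int]]  # int, or ("size", tensor, dim)


def view_sizes_supported(view_input: str, sizes: Sequence[SizeEntry]) -> bool:
    for s in sizes:
        if isinstance(s, int):
            continue  # static size
        if isinstance(s, tuple) and len(s) == 3 and s[0] == "size" and s[1] == view_input:
            continue  # size op applied to the view's own input
        return False  # e.g. a size taken from some other tensor
    return True
```

In the reproducer IR, both dynamic entries of %4 come from torch.aten.size.int on %arg0 rather than on the view input %3, so this check fails there.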

ramiro050, Feb 14 '24 00:02