More TorchToLinalg view failure cases
%114 = torch.aten.view %112, %113 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>
<stdin>:929:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
%7 = torch.aten.view %5, %6 : !torch.vtensor<[1,?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>
For the latter, I know @ramiro050 landed https://github.com/huggingface/transformers/pull/26059 to avoid this issue for many models (the views are often no-ops).
While optimizing model developer code is great, eternally patching huggingface models seems worse than supporting the lowering and then having a pass remove it as a no-op.
<stdin>:929:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
%7 = torch.aten.view %5, %6 : !torch.vtensor<[1,?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>
Should this be handled by the torch.aten.view folder (it's an identity)?
%114 = torch.aten.view %112, %113 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>
This possibly started life as an unsqueeze. If it did, we could lower to the torch-mlir unsqueeze op, which lowers directly to linalg.expand_shape. But I guess it depends on where this op came from and whether we can deduce that. This PR https://github.com/llvm/llvm-project/pull/69267 might be helpful too (if it was a reshape from (a, b) -> (1, 1, a/4, b*4), say).
It started its life as view(1, 1, key_size, key_size); I actually circumvented the error by rewriting it with unsqueezes. I suppose the problem with double dynamic dims like that is that you don't know whether they are the same or not unless you're working with strict symbolic shapes?
it started its life as view(1, 1, key_size, key_size)
Do you have a link to the PyTorch model definition where they do this view?
I suppose the problem with double dynamic dims like that is that you don't know whether they are the same or not unless you're working with strict symbolic shapes?
Yeah, this is a very annoying limitation. Since we don't have symbolic shapes, we cannot know if two dimensions of a tensor are the same without symbolically analyzing the op that produced the tensor (and potentially the ops before it). The patch referenced (https://github.com/llvm/llvm-project/pull/69267) would likely fix this, as would using the FX importer path into Torch-MLIR.
I am actually coming via the fx_importer, but I don't think the view lowering takes strict_symbolic_shapes into account?
here's a link:
https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L505
the other offending views are the same as from here: https://github.com/huggingface/transformers/pull/26059
I am actually coming via the fx_importer, but I don't think the view lowering takes strict_symbolic_shapes into account?
Oh, then maybe we can fix this by modifying the importer. The importer should have access to the symbolic shapes that PyTorch generates. Once we import, all the symbolic shape information is lost, and retrieving it can become very painful.
here's a link:
https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L505
the other offending views are the same as from here: huggingface/transformers#26059
Yeah, this is very annoying to fix once we are inside Torch-MLIR. One can definitely match on the IR, but it is hard to make a pattern general enough to be useful for slightly different situations. For example, in the issue I was facing, I could've solved it as follows:
- Match on a view op
- Look at the defining op of the operand of the view
- Make sure the defining op is a torch.rand
- Use the arguments of torch.rand to get the symbolic shape of the view operand
- Compare the symbolic shape with the shape list argument of the view
- Handle the case where the conclusion is that this view is just an unsqueeze (or identity)
However, this pattern is a bit complex and so fragile that it would not even fix your issue. In your case, you would have to do a similar thing but use torch.ones instead to get the symbolic shape, and you would also need to know that any ops in between do nothing to the shape of the tensor (here, that torch.tril only modifies the contents). It becomes a game of whack-a-mole.
I think we should explore making the importer a bit smarter first, and if that does not work, we can do an explicit pattern rewrite.
To reproduce this error
import torch
from torch_mlir import torchscript
class OnesTrilViewModule(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
key_size = x.shape[0]
return torch.tril(torch.ones((key_size, key_size), dtype=torch.bool)).view(
1, 1, key_size, key_size
)
tanh_example_input = torch.empty(
5,
)
placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[0])  # mark dim 0 as dynamic
out = torchscript.compile(OnesTrilViewModule(), placeholder)
print(out)
with open("torch.mlir", "w") as f:
f.write(str(out))
to get the IR
module attributes {torch.debug_module_name = "OnesTrilViewModule"} {
func.func @forward(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[1,1,?,?],i1> {
%int1 = torch.constant.int 1
%int0 = torch.constant.int 0
%int11 = torch.constant.int 11
%none = torch.constant.none
%0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?],f32>, !torch.int -> !torch.int
%1 = torch.prim.ListConstruct %0, %0 : (!torch.int, !torch.int) -> !torch.list<int>
%2 = torch.aten.ones %1, %int11, %none, %none, %none : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],i1>
%3 = torch.aten.tril %2, %int0 : !torch.vtensor<[?,?],i1>, !torch.int -> !torch.vtensor<[?,?],i1>
%4 = torch.prim.ListConstruct %int1, %int1, %0, %0 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.view %3, %4 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>
return %5 : !torch.vtensor<[1,1,?,?],i1>
}
}
Compiling through linalg (I used iree-compile torch.mlir --iree-hal-target-backends=llvm-cpu --compile-to=flow, but it could be anything), you get
torch.mlir:12:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
%5 = torch.aten.view %3, %4 : !torch.vtensor<[?,?],i1>, !torch.list<int> -> !torch.vtensor<[1,1,?,?],i1>
^
torch.mlir:12:10: note: see current operation: %36 = "torch.aten.view"(%34, %35) : (!torch.vtensor<[?,?],i1>, !torch.list<int>) -> !torch.vtensor<[1,1,?,?],i1>
I'm imagining a utility function which attempts to get a symbolic shape for any value.
FailureOr<ListConstruct> getSymbolicShape(Value v);
The implementation recursively walks through producers until it has something to return, or fails. Ops like Tril would contribute logic that says the result has the same shape as the operand. Ops like Ones and View would return their result shapes directly.
In a View canonicalizer, you would then call getSymbolicShape on the operand and compare it to the result shape to determine whether the View can be converted into an Unsqueeze.
@ramiro050 I'm curious what "making the importer a bit smarter" might look like (I know very little about the importer!)
@dan-garvey, are you using a similar import process as the reproducer? I think you mentioned you were using the fx importer rather than the torchscript one.
I'm curious what "making the importer a bit smarter" might look like (I know very little about the importer!)
When PyTorch generates an FX graph using Dynamo, (my understanding is that) all the shapes for the tensors in the graph are symbolic shapes. This means that when we are about to import the view op into torch-mlir, we can actually check the shape argument and see which shape dimension sizes match the dimension sizes of the input tensor. Then we can replace those with torch.size ops that get the sizes from the input tensor. For example, the case Dan is running into above
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device))
causal_mask = causal_mask.view(1, 1, key_size, key_size)
would become
causal_mask = torch.tril(
torch.ones((key_size, key_size), dtype=torch.bool, device=query.device))
-causal_mask = causal_mask.view(1, 1, key_size, key_size)
+causal_mask = causal_mask.view(1, 1, causal_mask.size(0), causal_mask.size(1))
This type of transformation can likely be written as an FX transformation in Python before we ingest the graph into Torch-MLIR. I think we should take advantage of the fact that the symbolic shape information is there rather than try to recreate it ourselves.
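For illustration, here is a rough, hedged sketch of what such an FX-level rewrite could look like on a graph produced by torch.export. The helper name rewrite_view_shapes and the matching strategy are made up (not existing importer code); it assumes the view shows up as torch.ops.aten.view.default, that nodes carry FakeTensor/SymInt metadata in node.meta["val"], and it mutates ep.graph_module in place purely for demonstration.
import torch
from torch.export import Dim, export

def rewrite_view_shapes(gm: torch.fx.GraphModule) -> None:
    # For every aten.view node, replace symbolic entries of the target shape with
    # sym_size calls on the view's own input whenever the symbolic sizes match.
    graph = gm.graph
    for node in list(graph.nodes):
        if node.target != torch.ops.aten.view.default:
            continue
        src, shape = node.args
        src_sizes = src.meta["val"].shape  # symbolic shape of the view input
        used_dims = set()
        new_shape = []
        for entry in shape:
            if isinstance(entry, torch.fx.Node):
                sym = entry.meta["val"]
                # Input dimensions whose symbolic size equals this shape entry.
                dims = [d for d, s in enumerate(src_sizes)
                        if d not in used_dims and str(s) == str(sym)]
                if dims:
                    used_dims.add(dims[0])
                    with graph.inserting_before(node):
                        entry = graph.call_function(
                            torch.ops.aten.sym_size.int, (src, dims[0]))
            new_shape.append(entry)
        node.args = (src, new_shape)
    graph.lint()
    gm.recompile()

class OnesTrilViewModule(torch.nn.Module):
    def forward(self, x):
        key_size = x.shape[0]
        mask = torch.tril(torch.ones((key_size, key_size), dtype=torch.bool))
        return mask.view(1, 1, key_size, key_size)

ep = export(OnesTrilViewModule(), (torch.empty(5),),
            dynamic_shapes={"x": {0: Dim("key_size", min=2)}})
rewrite_view_shapes(ep.graph_module)
print(ep.graph_module.graph)
As in the diff above, every dynamic entry of the shape list ends up expressed as a size of the view's own input rather than of some earlier tensor.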
Also, the importer I'm talking about is: https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/extras/fx_importer.py
Here's an example for how to use it:
https://github.com/llvm/torch-mlir/blob/main/test/python/fx_importer/basic_test.py
I haven't played around with fx transformations, but I completely agree that fundamentally we should take advantage of having access to the symbolic dims.
Given what you propose, I wonder if it wouldn't be better to represent all the dynamic dims as the output of a size op? I wonder how hard it would be to fuse all the redundant size ops this would create.
Given what you propose, I wonder if it wouldn't be better to represent all the dynamic dims as the output of a size op? I wonder how hard it would be to fuse all the redundant size ops this would create.
The challenge is representing all dynamic dims as the output of a size op applied to the input tensor of the view op; our conversion pattern for aten.view only knows how to deal with such size ops, and achieving that for the failure case here requires symbolic shape information.
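On the redundant-size-op question: fusing duplicates at the FX level is cheap graph surgery. Here is a minimal sketch, assuming the sizes appear as aten.sym_size.int nodes as in the earlier sketch (the helper name dedupe_sym_size is made up).
import torch

def dedupe_sym_size(gm: torch.fx.GraphModule) -> None:
    # Collapse aten.sym_size.int nodes that query the same tensor and dimension.
    seen = {}
    graph = gm.graph
    for node in list(graph.nodes):
        if node.target != torch.ops.aten.sym_size.int:
            continue
        key = (node.args[0], node.args[1])  # (tensor node, dim index)
        if key in seen:
            node.replace_all_uses_with(seen[key])
            graph.erase_node(node)
        else:
            seen[key] = node
    gm.recompile()
Run after a rewrite like the one above, this leaves a single size query per (tensor, dim) pair.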