torch-mlir
linalg fails to reshape in some cases
Using ClassAnnotator.annotateArgs, I annotated x as a ?x2 matrix (dynamic first dimension, second dimension fixed at 2), and I want to reshape it to 2x?:
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        # x: [?, 2] -> [2, ?]
        return x.reshape([x.shape[1], -1])
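For reference, the -1 entry in the target shape is resolved at runtime from the element count, which is why the result type has a static 2 and a dynamic ?. Below is a minimal pure-Python sketch of that rule, modeled on the documented semantics of torch.Tensor.view/reshape. The helper name infer_view_shape is hypothetical; it is not a torch-mlir or PyTorch API.

```python
from math import prod


def infer_view_shape(in_shape, target):
    """Resolve at most one -1 in `target` so the element count is preserved.

    Hypothetical helper for illustration only; not part of torch-mlir or PyTorch.
    """
    numel = prod(in_shape)
    if target.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    # Product of the explicitly given (non -1) target dims.
    known = prod(d for d in target if d != -1)
    if -1 not in target:
        if known != numel:
            raise ValueError("shape is invalid for input size")
        return list(target)
    if known == 0 or numel % known != 0:
        raise ValueError("shape is invalid for input size")
    return [numel // known if d == -1 else d for d in target]


# x has static type [?, 2]; try a few runtime values for the dynamic dim.
for n in (1, 3, 5):
    x_shape = [n, 2]
    print(x_shape, "->", infer_view_shape(x_shape, [x_shape[1], -1]))
```

Whatever the runtime size of the first dimension, the output's first dimension is always 2 and the second equals it, matching the [2,?] result type in the IR.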
⬇️ torchscript-module-to-torch-backend-pipeline
module attributes {torch.debug_module_name = "Model"} {
  func.func @forward(%arg0: !torch.vtensor<[?,2],f32>) -> !torch.vtensor<[2,?],f32> {
    %int2 = torch.constant.int 2
    %int-1 = torch.constant.int -1
    %0 = torch.prim.ListConstruct %int2, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,2],f32>, !torch.list<int> -> !torch.vtensor<[2,?],f32>
    return %1 : !torch.vtensor<[2,?],f32>
  }
}
⬇️ torch-backend-to-linalg-on-tensors-backend-pipeline failed ❌
/tmp/r.torch.mlir:6:10: error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal
    %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[?,2],f32>, !torch.list<int> -> !torch.vtensor<[2,?],f32>
         ^
/tmp/r.torch.mlir:6:10: note: see current operation: %3 = "torch.aten.view"(%arg0, %2) : (!torch.vtensor<[?,2],f32>, !torch.list<int>) -> !torch.vtensor<[2,?],f32>
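Note that the shape change itself needs no data movement: in row-major order, [n, 2] -> [2, n] is equivalent to collapsing to a 1-D buffer of 2*n elements and then expanding that buffer to two rows. In MLIR terms this resembles a tensor.collapse_shape followed by a tensor.expand_shape. The plain-Python sketch below only illustrates that equivalence on nested lists. The functions collapse and expand are hypothetical illustrations, and this is not a claim about how torch-mlir lowers (or should lower) torch.aten.view.

```python
def collapse(matrix):
    """Collapse a 2-D nested list into a 1-D list in row-major order."""
    return [v for row in matrix for v in row]


def expand(flat, rows):
    """Expand a 1-D list into `rows` equal-length rows (row-major)."""
    if rows <= 0 or len(flat) % rows != 0:
        raise ValueError("cannot expand evenly")
    cols = len(flat) // rows
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]


# x: shape [3, 2] (the dynamic dim happens to be 3 at runtime)
x = [[0, 1], [2, 3], [4, 5]]
# Target shape [2, ?]: collapse to 6 elements, then expand to 2 rows of 3.
y = expand(collapse(x), 2)
print(y)  # [[0, 1, 2], [3, 4, 5]]
```

This gives the same element order as x.reshape([x.shape[1], -1]) in PyTorch for a contiguous [3, 2] input.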