`AtenViewOp` lowers incorrectly for a statically shaped input and a `-1` dim
```mlir
module attributes {torch.debug_module_name = "test"} {
  func.func @test(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[32],f32>}) -> !torch.tensor {
    %int-1 = torch.constant.int -1
    %0 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %1 = torch.aten._unsafe_view %arg0, %0 : !torch.tensor, !torch.list<int> -> !torch.tensor
    return %1 : !torch.tensor
  }
}
```
lowers to:
```mlir
module attributes {torch.debug_module_name = "test"} {
  func.func @test(%arg0: tensor<32xf32>) -> tensor<?xf32> {
    %c32_i64 = arith.constant 32 : i64
    %c-1_i64 = arith.constant -1 : i64
    %0 = arith.cmpi eq, %c32_i64, %c-1_i64 : i64
    cf.assert %0, "mismatching contracting dimension"
    %1 = tensor.cast %arg0 : tensor<32xf32> to tensor<?xf32>
    return %1 : tensor<?xf32>
  }
}
```
The `%0 = arith.cmpi eq, %c32_i64, %c-1_i64 : i64` is spurious: `-1` in the view sizes is a placeholder meaning "infer this dimension" (here it resolves to 32), so comparing the input size against the literal `-1` is always false, and the subsequent `cf.assert` naturally fails. The root of the issue is the call to `checkDimEqualHelper` at https://github.com/llvm/torch-mlir/blob/24e04d5729d251c9f0cf0df85308de43a7b646a3/lib/Conversion/TorchToLinalg/DataMovement.cpp#L155-L160, which could (or should) be gated by something like `inputRank == resultRank && resultType.hasStaticShape()`. Happy to submit a PR with the "right" fix if just adding that gate isn't it.
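To make the intended semantics concrete, here is a minimal C++ sketch of resolving the `-1` entry before any per-dimension equality check is emitted. It is not torch-mlir's code: `resolveViewSizes` is a hypothetical helper, and it assumes a fully static input shape (as in this issue). It follows PyTorch's view rule that at most one size may be `-1`, inferred from the input's total element count.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not torch-mlir code): resolve the sizes passed to
// aten.view / aten._unsafe_view against a fully static input shape.
// At most one entry may be -1; its value is inferred from the element count.
// Returns std::nullopt when the sizes cannot describe the input.
std::optional<std::vector<int64_t>>
resolveViewSizes(const std::vector<int64_t> &inputShape,
                 std::vector<int64_t> sizes) {
  // Total number of elements in the (assumed static) input.
  int64_t numel = 1;
  for (int64_t d : inputShape)
    numel *= d;

  // Product of the explicit sizes, plus the position of the -1 (if any).
  int64_t known = 1;
  std::optional<std::size_t> inferIdx;
  for (std::size_t i = 0; i < sizes.size(); ++i) {
    if (sizes[i] == -1) {
      if (inferIdx)
        return std::nullopt; // more than one -1 is invalid
      inferIdx = i;
    } else if (sizes[i] < 0) {
      return std::nullopt; // any other negative size is invalid
    } else {
      known *= sizes[i];
    }
  }

  if (!inferIdx) {
    // No -1: the explicit sizes must account for every element exactly.
    if (known != numel)
      return std::nullopt;
    return sizes;
  }

  // Infer the -1 entry: ambiguous if the known product is zero, invalid if
  // it does not evenly divide the element count.
  if (known == 0 || numel % known != 0)
    return std::nullopt;
  sizes[*inferIdx] = numel / known;
  return sizes;
}
```

For the module above, `resolveViewSizes({32}, {-1})` yields `{32}`, so a later per-dimension check would compare 32 with 32 rather than with the literal `-1`. Resolving first is only one option; the gate suggested in the issue avoids the bogus comparison without computing the inferred size at all.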
For context, I hit this in eager mode with, e.g., `UnsafeView1DFoldModule_basic`, where I know (and annotate) the static shape of the input, as opposed to the `-1` (dynamic) dim annotation the test currently uses.
cc @silvasean @cathyzhyi
@qedawkins, can you confirm whether this is covered by our current e2e tests / lowering?
It looks like this was inadvertently fixed by #1082; however, there is no test for it. Do we want to add one?
Yes, let's add a test for it and close this. Thanks!