`torch.prim.If.yield` and `torch.tensor_static_info_cast` ops don't seem to get along
To reproduce the error:
Code
import torch

class MyIfModel(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            y = x
        else:
            y = -x
        return y

import torch_mlir

test_input = torch.rand(15)
torch_mlir.compile(MyIfModel(),
                   [test_input],
                   output_type="torch",
                   enable_ir_printing=True)
Error
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) ('builtin.module' operation) //----- //
module attributes {torch.debug_module_name = "MyIfModel"} {
  func.func @forward(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor {
    %int0 = torch.constant.int 0
    %none = torch.constant.none
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[15],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %2 = torch.copy.to_vtensor %1 : !torch.vtensor
    %3 = torch.aten.sum %2, %none : !torch.vtensor, !torch.none -> !torch.vtensor<[],unk>
    %4 = torch.aten.gt.Scalar %3, %int0 : !torch.vtensor<[],unk>, !torch.int -> !torch.vtensor<[],i1>
    %5 = torch.aten.Bool.Tensor %4 : !torch.vtensor<[],i1> -> !torch.bool
    %6 = torch.prim.If %5 -> (!torch.tensor) {
      torch.prim.If.yield %1 : !torch.tensor
    } else {
      %8 = torch.copy.to_vtensor %1 : !torch.vtensor
      %9 = torch.aten.neg %8 : !torch.vtensor -> !torch.vtensor<[15],unk>
      %10 = torch.tensor_static_info_cast %9 : !torch.vtensor<[15],unk> to !torch.vtensor
      %11 = torch.copy.to_tensor %10 : !torch.tensor
      torch.prim.If.yield %11 : !torch.tensor
    }
    %7 = torch.copy.to_vtensor %6 : !torch.vtensor
    return %7 : !torch.vtensor
  }
}
Traceback (most recent call last):
  File "/home/azureuser/torch-mlir/scratch/replicate_static_info_cast_error.py", line 15, in <module>
    torch_mlir.compile(MyIfModel(),
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 458, in compile
    run_pipeline_with_repro_report(
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
python exception: Failure while executing pass pipeline:
error: unknown: unsupported by backend contract: tensor with unknown rank
note: unknown: see current operation: %2 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[15],f32>) -> !torch.vtensor
note: unknown: this is likely due to a missing transfer function in abstract_interp_lib_gen.py
For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops= extra-library=})' /tmp/MyIfModel.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
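The diagnostic points at the first `torch.tensor_static_info_cast`, which casts the ranked argument to a bare `!torch.vtensor`. To trace where rank information disappears in a dump like the one above, a small script can list every SSA result whose printed type has no rank. This is a hypothetical, regex-based sketch, not a torch-mlir tool; `find_unranked_results` and its helpers are illustrative names, and it assumes the printed forms `<[sizes],dtype>` (ranked) versus a bare type or `<*,dtype>` (unknown rank).

```python
import re

# A torch tensor type: "!torch.vtensor<[15],f32>", "!torch.vtensor<*,f32>",
# or the fully unknown "!torch.vtensor" / "!torch.tensor".
# Group 1 is "tensor" or "vtensor"; group 2 is the optional "<...>" suffix.
TENSOR_TYPE = re.compile(r"!torch\.(v?tensor)(<[^>]*>)?")
# The SSA result on the left-hand side of "%name = ...".
RESULT = re.compile(r"^\s*(%[\w.]+)\s*=")


def is_ranked(params):
    # Ranked types spell out their sizes in brackets, e.g. "<[15],f32>" or
    # "<[],i1>"; an empty suffix or "<*,...>" means the rank is unknown.
    return params.startswith("<[")


def result_type_text(line):
    # Heuristic: printed ops put the result type after "->", after " to "
    # (casts), or, when operand and result types coincide, after the last ":".
    for sep in ("->", " to ", ":"):
        if sep in line:
            return line.rsplit(sep, 1)[1]
    return ""


def find_unranked_results(ir_text):
    """Return (line number, SSA name, type) for results with unknown rank."""
    hits = []
    for lineno, line in enumerate(ir_text.splitlines(), start=1):
        result = RESULT.match(line)
        if not result:
            continue  # yields, returns and block delimiters define no result
        for kind, params in TENSOR_TYPE.findall(result_type_text(line)):
            if not is_ranked(params):
                hits.append((lineno, result.group(1), f"!torch.{kind}{params}"))
    return hits


SAMPLE = """\
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[15],f32> to !torch.vtensor
%3 = torch.aten.sum %2, %none : !torch.vtensor, !torch.none -> !torch.vtensor<[],unk>
%5 = torch.aten.Bool.Tensor %4 : !torch.vtensor<[],i1> -> !torch.bool
%6 = torch.prim.If %5 -> (!torch.tensor) {
torch.prim.If.yield %1 : !torch.tensor
"""

for hit in find_unranked_results(SAMPLE):
    print(hit)
```

Run against the full dump, the first hit is `%0`, the same cast the diagnostic names; the scalar results `%3` and `%5` are skipped because their result types are ranked or not tensors at all.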
Using the if op but yielding a constant instead of a tensor causes no error:
import torch

class MyIfModel(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            y = 3
        else:
            y = -3
        return y

import torch_mlir

test_input = torch.rand(15)
torch_mlir.compile(MyIfModel(),
                   [test_input],
                   output_type="torch",
                   enable_ir_printing=True)
Returning the negated tensor without the if op also causes no error:
import torch

class MyIfModel(torch.nn.Module):
    def forward(self, x):
        return -x

import torch_mlir

test_input = torch.rand(15)
torch_mlir.compile(MyIfModel(),
                   [test_input],
                   output_type="torch",
                   enable_ir_printing=True)
@silvasean, tagging you since you're the author of both Ancient Ops of Power (`prim.If` and `tensor_static_info_cast`).
I doubt it is worth getting this working through the legacy TorchScript path. Focus on either FX or ONNX for testing.
@stellaraccident I ran into this while trying to get a small reference MLIR for the "if" op.
How can I do that without using the legacy TorchScript path?
This looks like a non-issue: one of my types had no shape information because I was using the legacy TorchScript path.
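Why one shapeless value is enough: the result of an `if` can only carry the facts that hold for the values yielded by both branches. The sketch below is a hypothetical, simplified model of that merge, not torch-mlir's type-refinement code; `TensorKnowledge` and `join` are illustrative names. A `sizes` of `None` means unknown rank, a `None` entry means an unknown dimension, and a `None` dtype prints as `unk`, assuming the printed `!torch.vtensor` forms seen in the dump.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorKnowledge:
    """Hypothetical summary of what is statically known about a tensor."""

    sizes: Optional[Tuple[Optional[int], ...]]  # None: rank unknown
    dtype: Optional[str]  # None: dtype unknown ("unk" in the IR)

    def __str__(self):
        # Mirror the printed !torch.vtensor forms seen in the IR dump.
        if self.sizes is None and self.dtype is None:
            return "!torch.vtensor"
        dtype = self.dtype or "unk"
        if self.sizes is None:
            return f"!torch.vtensor<*,{dtype}>"
        dims = ",".join("?" if d is None else str(d) for d in self.sizes)
        return f"!torch.vtensor<[{dims}],{dtype}>"


def join(a, b):
    """Keep only the facts that hold for a value coming from EITHER branch."""
    dtype = a.dtype if a.dtype == b.dtype else None
    if a.sizes is None or b.sizes is None or len(a.sizes) != len(b.sizes):
        return TensorKnowledge(None, dtype)  # rank can no longer be known
    sizes = tuple(x if x == y else None for x, y in zip(a.sizes, b.sizes))
    return TensorKnowledge(sizes, dtype)


x = TensorKnowledge((15,), "f32")      # the input: !torch.vtensor<[15],f32>
neg_x = TensorKnowledge((15,), "f32")  # -x keeps the shape and dtype of x
erased = TensorKnowledge(None, None)   # after a cast to a bare !torch.vtensor

print(join(x, neg_x))       # !torch.vtensor<[15],f32>
print(join(erased, neg_x))  # !torch.vtensor
```

Mismatched dimensions survive the merge as `?` while the rank stays known; only a missing rank on either side collapses the result to a bare `!torch.vtensor`, which is the kind of type the backend contract rejects as "tensor with unknown rank".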
Here's an example of how to do it correctly:
import torch
import shark_turbine

class IndexPutModule(torch.nn.Module):
    def __init__(self):
        super(IndexPutModule, self).__init__()

    def forward(self):
        condition = torch.tensor([True])
        value_a = 0
        value_b = 15
        return torch.ops.aten.where(condition, value_a, value_b)

module = IndexPutModule()

import shark_turbine.aot as aot

export_output = aot.export(module)
export_output.print_readable()