torch-mlir

`torch.prim.If.yield` and `torch.tensor_static_info_cast` ops don't seem to get along

Open renxida opened this issue 1 year ago • 3 comments

torch.prim.If.yield and torch.tensor_static_info_cast ops don't seem to get along.

To reproduce the error:

Code

import torch

class MyIfModel(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            y = x
        else:
            y = -x
        return y


import torch_mlir

test_input = torch.rand(15)
torch_mlir.compile(MyIfModel(),
                    [test_input],
                    output_type="torch",
                    enable_ir_printing=True)

Error

// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) ('builtin.module' operation) //----- //
module attributes {torch.debug_module_name = "MyIfModel"} {
  func.func @forward(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor {
    %int0 = torch.constant.int 0
    %none = torch.constant.none
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[15],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %2 = torch.copy.to_vtensor %1 : !torch.vtensor
    %3 = torch.aten.sum %2, %none : !torch.vtensor, !torch.none -> !torch.vtensor<[],unk>
    %4 = torch.aten.gt.Scalar %3, %int0 : !torch.vtensor<[],unk>, !torch.int -> !torch.vtensor<[],i1>
    %5 = torch.aten.Bool.Tensor %4 : !torch.vtensor<[],i1> -> !torch.bool
    %6 = torch.prim.If %5 -> (!torch.tensor) {
      torch.prim.If.yield %1 : !torch.tensor
    } else {
      %8 = torch.copy.to_vtensor %1 : !torch.vtensor
      %9 = torch.aten.neg %8 : !torch.vtensor -> !torch.vtensor<[15],unk>
      %10 = torch.tensor_static_info_cast %9 : !torch.vtensor<[15],unk> to !torch.vtensor
      %11 = torch.copy.to_tensor %10 : !torch.tensor
      torch.prim.If.yield %11 : !torch.tensor
    }
    %7 = torch.copy.to_vtensor %6 : !torch.vtensor
    return %7 : !torch.vtensor
  }
}


Traceback (most recent call last):
  File "/home/azureuser/torch-mlir/scratch/replicate_static_info_cast_error.py", line 15, in <module>
    torch_mlir.compile(MyIfModel(),
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/__init__.py", line 458, in compile
    run_pipeline_with_repro_report(
  File "/home/azureuser/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir/compiler_utils.py", line 73, in run_pipeline_with_repro_report
    raise TorchMlirCompilerError(trimmed_message) from None
torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:


python exception: Failure while executing pass pipeline:
error: unknown: unsupported by backend contract: tensor with unknown rank
note: unknown: see current operation: %2 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[15],f32>) -> !torch.vtensor
note: unknown: this is likely due to a missing transfer function in abstract_interp_lib_gen.py

For Torch-MLIR developers, the error can be reproduced with:
$ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops= extra-library=})' /tmp/MyIfModel.mlir
Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
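The diagnostic above complains about a "tensor with unknown rank". In a larger IR dump, one quick way to locate such values is to scan for tensor types that carry no shape annotation. The helper below, find_unranked_tensor_types, is a hypothetical regex-based sketch (it is not part of torch-mlir). It assumes the textual type syntax seen in the dump, where a bare !torch.vtensor or !torch.tensor has no static info, and it also treats a "<*,dtype>" parameter list as unknown sizes.

```python
import re
from typing import List, Optional, Tuple

# Matches "!torch.vtensor" or "!torch.tensor", optionally followed by its
# "<sizes,dtype>" parameter list, e.g. "!torch.vtensor<[15],f32>".
_TENSOR_TYPE = re.compile(r"!torch\.v?tensor(<[^>]*>)?")


def _has_unknown_rank(params: Optional[str]) -> bool:
    # No "<...>" at all: neither sizes nor dtype are known.
    if params is None:
        return True
    # "<*,f32>": the dtype may be known, but the sizes (and so the rank) are not.
    return params.startswith("<*")


def find_unranked_tensor_types(ir_text: str) -> List[Tuple[int, str]]:
    """Return (line_number, line) for each IR line that mentions at least one
    tensor type whose rank is unknown (hypothetical helper, regex-based)."""
    hits = []
    for lineno, line in enumerate(ir_text.splitlines(), start=1):
        if any(_has_unknown_rank(m.group(1)) for m in _TENSOR_TYPE.finditer(line)):
            hits.append((lineno, line.strip()))
    return hits


if __name__ == "__main__":
    dump = """\
%0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[15],f32> to !torch.vtensor
%4 = torch.aten.gt.Scalar %3, %int0 : !torch.vtensor<[],unk>, !torch.int -> !torch.vtensor<[],i1>
"""
    for lineno, line in find_unranked_tensor_types(dump):
        print(f"line {lineno}: {line}")
```

Run on the full dump above, this would flag the %0 = torch.tensor_static_info_cast line named in the diagnostic, along with the other shapeless values such as the function result and the prim.If result. A regex over printed IR is only a debugging aid; it does not replace the compiler's own backend-contract check.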

Using only an if op that yields a constant causes no error:

import torch

class MyIfModel(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            y = 3
        else:
            y = -3
        return y


import torch_mlir

test_input = torch.rand(15)
torch_mlir.compile(MyIfModel(),
                    [test_input],
                    output_type="torch",
                    enable_ir_printing=True)

Returning -x without an if op also causes no error:

import torch

class MyIfModel(torch.nn.Module):
    def forward(self, x):
        return -x


import torch_mlir

test_input = torch.rand(15)
torch_mlir.compile(MyIfModel(),
                    [test_input],
                    output_type="torch",
                    enable_ir_printing=True)

renxida • Jan 30 '24 01:01

@silvasean, tagging you because you authored both of these Ancient Ops of Power (prim.If and tensor_static_info_cast).

renxida • Jan 30 '24 01:01

I doubt it is worth getting this working through the legacy TorchScript path. Focus on either FX or ONNX for testing.

stellaraccident • Jan 30 '24 04:01

@stellaraccident I ran into this while trying to produce a small reference MLIR for the "if" op.

How can I do this without using the legacy TorchScript path?

renxida • Jan 31 '24 17:01

This looks like a non-issue: one of my types had no shape information because I was using the legacy TorchScript path.
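As a rough illustration of why a single shapeless type is enough to lose the rank: when static info flows through a control-flow merge such as prim.If, the result can only be as precise as what both branches agree on. The sketch below is a toy model; TensorInfo and join are hypothetical names, not torch-mlir's actual type-refinement code. It uses the two branch types from the IR dump above.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorInfo:
    # Toy static info: sizes=None means unknown rank, -1 marks an unknown
    # dimension, and dtype=None means unknown dtype.
    sizes: Optional[Tuple[int, ...]]
    dtype: Optional[str]


def join(a: TensorInfo, b: TensorInfo) -> TensorInfo:
    """Most specific info that is still true for both branch results."""
    if a.sizes is None or b.sizes is None or len(a.sizes) != len(b.sizes):
        # One side has no rank (or the ranks disagree): the rank is lost.
        sizes = None
    else:
        # Keep a dimension only where both branches agree on it.
        sizes = tuple(x if x == y else -1 for x, y in zip(a.sizes, b.sizes))
    dtype = a.dtype if a.dtype == b.dtype else None
    return TensorInfo(sizes, dtype)


# Then-branch: shape erased by the cast to a bare !torch.vtensor.
erased = TensorInfo(sizes=None, dtype=None)
# Else-branch: the neg result, !torch.vtensor<[15],unk>.
negated = TensorInfo(sizes=(15,), dtype=None)

print(join(erased, negated))                    # TensorInfo(sizes=None, dtype=None)
print(join(TensorInfo((15,), "f32"), negated))  # TensorInfo(sizes=(15,), dtype=None)
```

In this toy model, once one branch has lost its rank, no amount of information on the other branch can restore it, which is why the fix is to keep shapes intact upstream rather than patch the merge.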

Here's an example of how to do it correctly, exporting with shark_turbine instead of the legacy TorchScript path:

import torch
import shark_turbine.aot as aot


class IndexPutModule(torch.nn.Module):
    def forward(self):
        # Constant boolean condition; aten.where picks value_a where it is True.
        condition = torch.tensor([True])
        value_a = 0
        value_b = 15
        return torch.ops.aten.where(condition, value_a, value_b)


module = IndexPutModule()
# Export through shark_turbine's AOT API instead of torch_mlir.compile.
export_output = aot.export(module)
export_output.print_readable()

renxida • Mar 26 '24 16:03