torch-mlir
When compiling a custom op, got error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
Test code:

```python
import pathlib

import numpy as np
import torch
from torch_mlir import compiler_utils, torchscript


@torch.library.custom_op("mylib::numpy_sin", mutates_args=[], schema='(Tensor x) -> Tensor')
def numpy_sin(input_tensor: torch.Tensor) -> torch.Tensor:
    input_np = input_tensor.numpy()
    output_np = np.zeros_like(input_np)
    np.sin(input_np, out=output_np)
    return torch.from_numpy(output_np).to(device=input_tensor.device)


class NumpySinModule(torch.nn.Module):
    def forward(self, a: torch.Tensor) -> torch.Tensor:
        return numpy_sin(a)


def test_custom_numpy_sin(tmp_path):
    inputs = [torch.ones(128, 128, dtype=torch.float)]
    model = NumpySinModule().eval()
    compile_type = compiler_utils.OutputType.TORCH
    result = torchscript.compile(
        model, inputs[0], output_type=compile_type, use_tracing=True, enable_ir_printing=True
    )
    assert torch.allclose(result, inputs[0].sin())


if __name__ == '__main__':
    test_custom_numpy_sin(pathlib.Path('.'))
```
Bottom of the error output:
```
>       raise TorchMlirCompilerError(trimmed_message) from None
E       torch_mlir.compiler_utils.TorchMlirCompilerError: Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
E       error: failed to legalize operation 'torch.operator' that was explicitly marked illegal
E       note: see current operation: %3 = "torch.operator"(%2) <{name = "mylib.numpy_sin"}> : (!torch.tensor<[128,128],f32>) -> !torch.tensor<[128,128],f32>
E
E
E       python exception: Failure while executing pass pipeline
E
E       For Torch-MLIR developers, the error can be reproduced with:
E       $ torch-mlir-opt -pass-pipeline='builtin.module(torchscript-module-to-torch-backend-pipeline{backend-legal-ops= extra-library=})' /tmp/NumpySinModule.mlir
E       Add '-mlir-print-ir-after-all -mlir-disable-threading' to get the IR dump for debugging purpose.
```
Last dumped IR:
```mlir
// -----// IR Dump Before ReduceOpVariants (torch-reduce-op-variants) ('func.func' operation: @forward) //----- //
module attributes {torch.debug_module_name = "NumpySinModule"} {
  func.func @forward(%arg0: !torch.vtensor<[128,128],f32>) -> !torch.tensor {
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[128,128],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %2 = torch.tensor_static_info_cast %1 : !torch.tensor to !torch.tensor<[128,128],f32>
    %3 = torch.operator "mylib.numpy_sin"(%2) : (!torch.tensor<[128,128],f32>) -> !torch.tensor<[128,128],f32>
    %4 = torch.tensor_static_info_cast %3 : !torch.tensor<[128,128],f32> to !torch.tensor
    return %4 : !torch.tensor
  }
}
// -----// IR Dump After ReduceOpVariants Failed (torch-reduce-op-variants) ('func.func' operation: @forward) //----- //
module attributes {torch.debug_module_name = "NumpySinModule"} {
  func.func @forward(%arg0: !torch.vtensor<[128,128],f32>) -> !torch.tensor {
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[128,128],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %2 = torch.tensor_static_info_cast %1 : !torch.tensor to !torch.tensor<[128,128],f32>
    %3 = torch.operator "mylib.numpy_sin"(%2) : (!torch.tensor<[128,128],f32>) -> !torch.tensor<[128,128],f32>
    %4 = torch.tensor_static_info_cast %3 : !torch.tensor<[128,128],f32> to !torch.tensor
    return %4 : !torch.tensor
  }
}
// -----// IR Dump After LowerToBackendContract Failed (torch-lower-to-backend-contract) ('builtin.module' operation) //----- //
module attributes {torch.debug_module_name = "NumpySinModule"} {
  func.func @forward(%arg0: !torch.vtensor<[128,128],f32>) -> !torch.tensor {
    %0 = torch.tensor_static_info_cast %arg0 : !torch.vtensor<[128,128],f32> to !torch.vtensor
    %1 = torch.copy.to_tensor %0 : !torch.tensor
    %2 = torch.tensor_static_info_cast %1 : !torch.tensor to !torch.tensor<[128,128],f32>
    %3 = torch.operator "mylib.numpy_sin"(%2) : (!torch.tensor<[128,128],f32>) -> !torch.tensor<[128,128],f32>
    %4 = torch.tensor_static_info_cast %3 : !torch.tensor<[128,128],f32> to !torch.tensor
    return %4 : !torch.tensor
  }
}
```
Hi @0x00-pl,
I encountered the same error.
Could you share how you lowered your custom op through ReduceOpVariants?
Thank you!