[TOSA] Failure compiling a model that divides a tensor by a constant (tracing mode)
The following Python script fails:
import torch
import torch_mlir
class Model(torch.nn.Module):
    def forward(self, x):
        return x / 2.0
model = Model()
model.eval()
test_input = torch.rand((1,))
module = torch_mlir.compile(
model, [test_input],
output_type=torch_mlir.OutputType.TOSA,
use_tracing=True)
module.operation.print()
Error Message:
/tmp/Model.mlir:
#loc = loc(unknown)
module attributes {torch.debug_module_name = "Model"} {
func.func @forward(%arg0: !torch.vtensor<[1],f32> loc(unknown)) -> !torch.vtensor<[1],f32> {
%0 = torch.vtensor.literal(dense<2.000000e+00> : tensor<f64>) : !torch.vtensor<[],f64> loc(#loc1)
%1 = torch.aten.div.Tensor %arg0, %0 : !torch.vtensor<[1],f32>, !torch.vtensor<[],f64> -> !torch.vtensor<[1],f32> loc(#loc1)
return %1 : !torch.vtensor<[1],f32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc1 = loc("/tmp/torch-mlir-tosa-bug.py":7:0)
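The IR above divides an f32 tensor by an f64 literal and still produces an f32 result. That matches PyTorch's type promotion, where a zero-rank tensor does not widen the result dtype when a dimensioned tensor of the same category is present, so the lowering has to bring the f64 constant down to f32 somewhere. A minimal pure-Python sketch of that rule, restricted to float dtypes; `result_float_dtype` is a hypothetical helper, not a PyTorch API:

```python
# Simplified model of PyTorch type promotion for floating-point operands only.
# Assumption: ints, bools, complex, and Python scalars are not modelled, and
# f16/bf16 mixing (which promotes to f32 in PyTorch) is left out.
FLOAT_WIDTH = {"f16": 16, "f32": 32, "f64": 64}


def result_float_dtype(operands):
    """operands: list of (dtype, rank) pairs, e.g. [("f32", 1), ("f64", 0)]."""
    dimensioned = [dtype for dtype, rank in operands if rank > 0]
    # Zero-rank tensors only decide the dtype when no dimensioned tensor
    # of the same category is present.
    candidates = dimensioned or [dtype for dtype, _ in operands]
    return max(candidates, key=FLOAT_WIDTH.__getitem__)


# %arg0 is a rank-1 f32 tensor and the literal is a rank-0 f64 tensor.
print(result_float_dtype([("f32", 1), ("f64", 0)]))  # f32
print(result_float_dtype([("f32", 0), ("f64", 0)]))  # f64
```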
Output of torch-mlir-opt -mlir-print-ir-after-all -mlir-disable-threading -pass-pipeline='builtin.module(torch-backend-to-tosa-backend-pipeline)' /tmp/Model.mlir:
/tmp/torch-mlir-tosa-bug.py:7:0: error: 'tosa.reciprocal' op operand #0 must be tensor of number values, but got 'tensor<f64>'
/tmp/torch-mlir-tosa-bug.py:7:0: note: see current operation: %2 = "tosa.reciprocal"(%1) : (tensor<f64>) -> tensor<f64>
// -----// IR Dump After ConvertTorchToTosa Failed (convert-torch-to-tosa) //----- //
"func.func"() <{function_type = (!torch.vtensor<[1],f32>) -> !torch.vtensor<[1],f32>, sym_name = "forward"}> ({
^bb0(%arg0: !torch.vtensor<[1],f32>):
%0 = "torch_c.to_builtin_tensor"(%arg0) : (!torch.vtensor<[1],f32>) -> tensor<1xf32>
%1 = "tosa.const"() <{value = dense<2.000000e+00> : tensor<f64>}> : () -> tensor<f64>
%2 = "tosa.reciprocal"(%1) : (tensor<f64>) -> tensor<f64>
%3 = "tosa.cast"(%2) : (tensor<f64>) -> tensor<f32>
%4 = "tosa.mul"(%0, %3) <{shift = 0 : i32}> : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
%5 = "torch_c.from_builtin_tensor"(%4) : (tensor<1xf32>) -> !torch.vtensor<[1],f32>
"func.return"(%5) : (!torch.vtensor<[1],f32>) -> ()
}) : () -> ()
Tested with commit b9d29dc.
@eric-k256, do you know what's going on here? The IR seems fine to me. Maybe reciprocal does not work for zero-rank tensors.
The TOSA IR verification fails because tosa.reciprocal, like all other TOSA operations except tosa.cast and tosa.const, does not accept f64 tensors. See mlir/Dialect/Tosa/IR/TosaTypesBase.td.
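To make the rule concrete, here is a small pure-Python model of that check applied to the TOSA ops from the failing dump. It is a sketch, not the real verifier: `rejected_ops` is a hypothetical helper, and the allowed-type set is an assumed simplification of TosaTypesBase.td (integer types are left out).

```python
# Simplified stand-in for the TOSA operand type constraints (assumption:
# integer types and the exact TosaTypesBase.td definitions are omitted).
TOSA_FLOATS = {"f16", "bf16", "f32"}
# Only tosa.const and tosa.cast additionally accept f64.
F64_ALLOWED_OPS = {"tosa.const", "tosa.cast"}


def rejected_ops(ops):
    """Return the names of ops whose element types would fail verification.

    `ops` is a list of (op_name, element_types) pairs, where element_types
    lists every operand and result element type of that op.
    """
    rejected = []
    for name, elem_types in ops:
        allowed = (TOSA_FLOATS | {"f64"}) if name in F64_ALLOWED_OPS else TOSA_FLOATS
        if any(t not in allowed for t in elem_types):
            rejected.append(name)
    return rejected


# The TOSA ops from the ConvertTorchToTosa dump above.
LOWERED = [
    ("tosa.const", ["f64"]),
    ("tosa.reciprocal", ["f64", "f64"]),
    ("tosa.cast", ["f64", "f32"]),
    ("tosa.mul", ["f32", "f32", "f32"]),
]

print(rejected_ops(LOWERED))  # ['tosa.reciprocal']
```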
@tbartsch-semron Hi tbartsch-semron, I have also run into this failure. Could you please tell me how you solved this problem? Thanks so much!
You could consider rewriting this operator to use TOSA fp32 ops, perhaps with a custom pass.
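One way to read that suggestion: since the divisor is a compile-time constant, the `const(f64) -> reciprocal -> cast(f32)` sequence from the dump can be folded into a single f32 constant, so no f64 value reaches any TOSA op. The sketch below shows only the rewrite logic on a hypothetical `Op` list; it is not the torch-mlir or MLIR pass API, and a real pass would be written as an MLIR rewrite pattern. It assumes a correctly rounded f64 reciprocal followed by a round-to-nearest cast, which is what `round_to_f32(1.0 / c)` reproduces.

```python
import struct
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Op:
    name: str                      # e.g. "tosa.const"
    elem_type: str                 # element type of the op's result
    value: Optional[float] = None  # payload, only used by tosa.const


def round_to_f32(x: float) -> float:
    """Round a Python float (IEEE f64) to the nearest f32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def fold_f64_reciprocal(ops: List[Op]) -> List[Op]:
    """Replace const(f64) -> reciprocal(f64) -> cast(f32) with one f32 const.

    Assumption: the three ops are consecutive and each consumes the previous
    one's result. A zero divisor and values outside the f32 range are not handled.
    """
    out, i = [], 0
    while i < len(ops):
        window = ops[i:i + 3]
        names = [op.name for op in window]
        types = [op.elem_type for op in window]
        if names == ["tosa.const", "tosa.reciprocal", "tosa.cast"] and types == ["f64", "f64", "f32"]:
            # Fold at compile time: f64 reciprocal, then a single rounding to f32.
            out.append(Op("tosa.const", "f32", round_to_f32(1.0 / window[0].value)))
            i += 3
        else:
            out.append(ops[i])
            i += 1
    return out


lowered = [
    Op("tosa.const", "f64", 2.0),
    Op("tosa.reciprocal", "f64"),
    Op("tosa.cast", "f32"),
    Op("tosa.mul", "f32"),  # multiplies the input by the (folded) constant
]
folded = fold_f64_reciprocal(lowered)
print([(op.name, op.elem_type, op.value) for op in folded])
# [('tosa.const', 'f32', 0.5), ('tosa.mul', 'f32', None)]
```

Folding the reciprocal instead of casting the constant first keeps the numerics identical to the original reciprocal-then-cast pair, at the cost of only handling constant divisors.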