[ONNX] Bernoulli operator implementation might be wrong (value mismatch in e2e testing when using function expansion)
Issue I discovered while working on #3384.
Tested on main, commit 0b46d1110aa9710a4c2935723c47dfe3d5c21fd3.
Normally the ONNX Bernoulli operator gets imported as `torch.operator "onnx.Bernoulli"`, but since ONNX provides a function definition for this operator, it's possible to make the importer pre-expand it (using the code added in https://github.com/llvm/torch-mlir/pull/3409).
If we apply this patch:

```diff
diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index 9fe29212..396096da 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -104,6 +104,7 @@ class Config:
             # Default domain (ONNX built-in ops)
             "": {
                 "MeanVarianceNormalization",
+                "Bernoulli",
             }
         }
     )
```
then the new importer output becomes:

```mlir
module {
  func.func @main_graph(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} {
    %none = torch.constant.none
    %0 = call @"('Bernoulli', '', 20, [tensor_type {\0A elem_type: 11\0A shape {\0A dim {\0A dim_param: \22dim_0_0\22\0A }\0A dim {\0A dim_param: \22dim_0_1\22\0A }\0A }\0A}\0A], [tensor_type {\0A elem_type: 11\0A shape {\0A dim {\0A dim_param: \22dim_0_0\22\0A }\0A dim {\0A dim_param: \22dim_0_1\22\0A }\0A }\0A}\0A], input: \22input_0\22\0Aoutput: \221\22\0Aname: \22/Bernoulli\22\0Aop_type: \22Bernoulli\22\0A)"(%arg0) : (!torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64>
    return %0 : !torch.vtensor<[?,?],f64>
  }
  func.func private @"('Bernoulli', '', 20, [tensor_type {\0A elem_type: 11\0A shape {\0A dim {\0A dim_param: \22dim_0_0\22\0A }\0A dim {\0A dim_param: \22dim_0_1\22\0A }\0A }\0A}\0A], [tensor_type {\0A elem_type: 11\0A shape {\0A dim {\0A dim_param: \22dim_0_0\22\0A }\0A dim {\0A dim_param: \22dim_0_1\22\0A }\0A }\0A}\0A], input: \22input_0\22\0Aoutput: \221\22\0Aname: \22/Bernoulli\22\0Aop_type: \22Bernoulli\22\0A)"(%arg0: !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.RandomUniformLike"(%arg0) {torch.onnx.dtype = 11 : si64, torch.onnx.high = 1.000000e+00 : f32, torch.onnx.low = 0.000000e+00 : f32} : (!torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],f64>
    %1 = torch.operator "onnx.Greater"(%0, %arg0) : (!torch.vtensor<[?,?],f64>, !torch.vtensor<[?,?],f64>) -> !torch.vtensor<[?,?],i1>
    %2 = torch.operator "onnx.Cast"(%1) {torch.onnx.to = 11 : si64} : (!torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],f64>
    return %2 : !torch.vtensor<[?,?],f64>
  }
}
```
and the e2e tests for Bernoulli start failing:

```console
$ python -m e2e_testing.main -f Bernoulli -c onnx
[...]
XFAIL - "BernoulliFloatModule_basic"
XFAIL - "BernoulliModule_basic"
FAIL - "BernoulliOnesModule_basic"
XFAIL - "BernoulliPModule_basic"
XFAIL - "BernoulliTensorModule_basic"
FAIL - "BernoulliZerosModule_basic"

Unexpected outcome summary: (onnx)

****** Failed tests - 2 tests
    FAIL - "BernoulliOnesModule_basic"
    FAIL - "BernoulliZerosModule_basic"

Summary:
    Failed: 2
    Expectedly Failed: 4
```
When I investigated why this happens, it turned out that the ONNX function interprets the operator's input (let's call it p) in the opposite way to what these tests expect: p is a probability in [0, 1], but the function behaves as if 1 - p had been passed. So where an all-ones result is expected, it produces all zeros, and vice versa.
Looking at the importer output above, we can see ONNX's definition is very simple: generate uniform random numbers (each in the range [0, 1], I believe), compare them elementwise against p with greater-than, and cast the boolean result (false or true) to 0 or 1. To get the "expected" behavior, the greater-than would have to be replaced with a different comparison (perhaps less-than-or-equal).
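The mismatch can be reproduced outside the compiler with a small NumPy sketch of the two interpretations (the function names here are mine, for illustration; they don't come from either codebase):

```python
import numpy as np

rng = np.random.default_rng(0)

def bernoulli_onnx_function(p):
    # Mirrors the expanded ONNX function above:
    # RandomUniformLike -> Greater(random, p) -> Cast
    random = rng.uniform(0.0, 1.0, size=p.shape)
    return (random > p).astype(np.float64)

def bernoulli_expected(p):
    # What the torch-mlir e2e tests expect: each output element is 1
    # with probability p, so the comparison goes the other way around.
    random = rng.uniform(0.0, 1.0, size=p.shape)
    return (random <= p).astype(np.float64)

# With p all ones, the two definitions diverge deterministically,
# since uniform samples lie in [0, 1):
p = np.ones((2, 3))
print(bernoulli_onnx_function(p))  # all zeros: random > 1 is never true
print(bernoulli_expected(p))       # all ones: random <= 1 is always true
```

This is exactly the BernoulliOnesModule/BernoulliZerosModule situation: for p = 1 the tests expect all ones, but the function expansion yields all zeros.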
To me, this clearly indicates a bug, but I'm not sure which implementation is "wrong" and which is "right".