secret: crash in canonicalize with constant
Reproducer:
func.func @func() -> !secret.secret<i2> {
  %false = arith.constant false
  %0 = secret.generic ins(%false : i1) {
  ^bb0(%arg1: i1):
    %1 = tensor.from_elements %arg1, %arg1 : tensor<2xi1>
    secret.yield %1 : tensor<2xi1>
  } -> !secret.secret<tensor<2xi1>>
  %1 = secret.cast %0 : !secret.secret<tensor<2xi1>> to !secret.secret<i2>
  return %1 : !secret.secret<i2>
}
Running heir-opt --canonicalize on this input fails with:
WARNING: Logging before InitGoogle() is written to STDERR
F0000 00:00:1747150597.830243 416504 logging.cc:61] assert.h assertion failed at third_party/llvm/llvm-project/llvm/include/llvm/Support/Casting.h:566 in decltype(auto) llvm::cast(const From &) [To = mlir::detail::TypedValue<mlir::heir::secret::SecretType>, From = mlir::Value]: isa<To>(Val) && "cast<Ty>() argument of incompatible type!"
*** Check failure stack trace: ***
@ 0x558b05f266c9 absl::log_internal::LogMessage::PrepareToDie()
@ 0x558b05f25c67 absl::log_internal::LogMessage::SendToLog()
@ 0x558b05f25c1b absl::log_internal::LogMessage::Flush()
@ 0x558af7d151f5 absl::log_internal::Voidify::operator&&<>()
@ 0x558b05e4f26a __assert_fail
@ 0x558af7ac7f1e llvm::cast<>()
@ 0x558af7ac6f5b mlir::heir::secret::CastOp::getInput()
@ 0x558afd903ed2 mlir::heir::secret::CastOp::fold()
@ 0x558afd8f51d8 mlir::Op<>::foldSingleResultHook<>()
@ 0x558afd8f511c mlir::Op<>::getFoldHookFn()::{lambda()#1}::operator()()
@ 0x558afd8f50ae llvm::detail::UniqueFunctionBase<>::CallImpl<>()
@ 0x558af724b96e llvm::unique_function<>::operator()()
@ 0x558afd8f4475 mlir::RegisteredOperationName::Model<>::foldHook()
@ 0x558afedb6afc mlir::OperationName::foldHook()
@ 0x558afedad461 mlir::Operation::fold()
@ 0x558afedada2b mlir::Operation::fold()
@ 0x558afe1d8eed (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist()
@ 0x558afe1d87a0 (anonymous namespace)::RegionPatternRewriteDriver::simplify()::$_2::operator()()
@ 0x558afe1d8768 llvm::function_ref<>::callback_fn<>()
@ 0x558afab661bc llvm::function_ref<>::operator()()
@ 0x558afe1d7b6d mlir::MLIRContext::executeAction<>()
@ 0x558afe1d5156 (anonymous namespace)::RegionPatternRewriteDriver::simplify()
@ 0x558afe1d4c43 mlir::applyPatternsGreedily()
@ 0x558af7255fac mlir::applyPatternsGreedily()
@ 0x558afe109f97 (anonymous namespace)::Canonicalizer::runOnOperation()
@ 0x558afe91c889 mlir::detail::OpToOpPassAdaptor::run()::$_1::operator()()
@ 0x558afe91c7f8 llvm::function_ref<>::callback_fn<>()
@ 0x558afab661bc llvm::function_ref<>::operator()()
@ 0x558afe91f7cd mlir::MLIRContext::executeAction<>()
@ 0x558afe91128d mlir::detail::OpToOpPassAdaptor::run()
@ 0x558afe911cab mlir::detail::OpToOpPassAdaptor::runPipeline()
@ 0x558afe9143df mlir::PassManager::runPasses()
@ 0x558afe914285 mlir::PassManager::run()
@ 0x558afab35d30 performActions()
@ 0x558afab357e7 processBuffer()
@ 0x558afab35534 mlir::MlirOptMain()::$_0::operator()()
@ 0x558afab35488 llvm::function_ref<>::callback_fn<>()
@ 0x558afee487ac llvm::function_ref<>::operator()()
@ 0x558afee47bb2 mlir::splitAndProcessBuffer()
@ 0x558afab30386 mlir::MlirOptMain()
@ 0x558afab30888 mlir::MlirOptMain()
@ 0x558afab30aab mlir::MlirOptMain()
@ 0x558af7138db1 main
@ 0x7f9cafafd3d4 __libc_start_main
@ 0x558af7137dea _start
Running in debug mode shows how the IR reaches that state:
RemoveNonSecretGenericArgs removes the %false argument:
func.func @func() -> !secret.secret<i2> {
  %false = arith.constant false
  %0 = secret.generic {
    %from_elements = tensor.from_elements %false, %false : tensor<2xi1>
    secret.yield %from_elements : tensor<2xi1>
  } -> !secret.secret<tensor<2xi1>>
  %1 = secret.cast %0 : !secret.secret<tensor<2xi1>> to !secret.secret<i2>
  return %1 : !secret.secret<i2>
}
tensor.from_elements is then folded into a constant:
func.func @func() -> !secret.secret<i2> {
  %false = arith.constant false
  %0 = secret.generic {
    %cst = arith.constant dense<false> : tensor<2xi1>
    secret.yield %cst : tensor<2xi1>
  } -> !secret.secret<tensor<2xi1>>
  %1 = secret.cast %0 : !secret.secret<tensor<2xi1>> to !secret.secret<i2>
  return %1 : !secret.secret<i2>
}
mlir::heir::secret::CollapseSecretlessGeneric then matches, since the generic no longer has any secret operands, and collapses the IR to:
"func.func"() <{function_type = () -> !secret.secret<i2>, sym_name = "func"}> ({
  %0 = "arith.constant"() <{value = dense<false> : tensor<2xi1>}> : () -> tensor<2xi1>
  %1 = "secret.cast"(%0) : (tensor<2xi1>) -> !secret.secret<i2>
  "func.return"(%1) : (!secret.secret<i2>) -> ()
}) : () -> ()
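At this point the IR no longer verifies, because secret.cast now has a non-secret input (tensor<2xi1>).
The subsequent fold then crashes. CastOp::fold calls CastOp::getInput(), which does llvm::cast<TypedValue<SecretType>> on the operand, and the assertion in llvm::cast fires because that operand is no longer secret.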
My proposed fix is a pattern that converts constants yielded from a secret.generic into secret.conceal ops. I'll put it up for review so we can debate it.
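For illustration, here is a hypothetical sketch of the IR such a pattern might leave for the reproducer. It assumes that the yielded constant is hoisted out of the generic and wrapped in secret.conceal. The actual output of the pattern under review may differ, for example in op ordering or in whether the dead %false survives.

```mlir
func.func @func() -> !secret.secret<i2> {
  // Assumed: the constant formerly yielded by the secret.generic is hoisted out...
  %cst = arith.constant dense<false> : tensor<2xi1>
  // ...and explicitly concealed, so the value stays secret-typed.
  %0 = secret.conceal %cst : tensor<2xi1> -> !secret.secret<tensor<2xi1>>
  // secret.cast still receives a !secret.secret operand, so it verifies and folds safely.
  %1 = secret.cast %0 : !secret.secret<tensor<2xi1>> to !secret.secret<i2>
  return %1 : !secret.secret<i2>
}
```

Because the conceal result keeps the type !secret.secret<tensor<2xi1>>, neither the secret.cast operand type nor the function signature changes.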
However, this fix breaks the following test: https://github.com/google/heir/blob/a422f130fa42ed2acf319a1bd6e420b651863c03/tests/Dialect/Secret/Transforms/canonicalize/secretless.mlir#L4
That test makes me question the pattern itself. I would expect a secret.generic to be collapsed only when it returns plaintext types, which isn't even allowed. Why should it collapse here and change the return type of the func?