tensor_ext: Add canonicalization patterns for tensor_ext.rotate_and_reduce
Examples:
- Canonicalize a rotate_and_reduce whose plaintexts all match the constant 1 by removing the plaintext arg
- If there's a rotate_and_reduce preceded or followed by another rotation+mulop of the ciphertext, fold that in and increase the step count / add the plaintext mul arg to the reduction plaintexts
This issue has 1 outstanding TODOs:
- lib/Dialect/TensorExt/IR/TensorExtOps.td:253: Add canonicalization patterns
This comment was autogenerated by todo-backlinks
I thought I'd give it a try and write some patterns, but I got stuck on the first one. The problem is that $plaintexts is defined as an optional operand.
The following pattern does not compile; it fails with "argument number mismatch: 3 in pattern vs. 4 in definition":
def DropAllOnesPlaintexts : Pat<
  (TensorExt_RotateAndReduceOp $tensor,
    (Arith_ConstantOp SplatElementsAttr:$constantAttr):$plaintexts,
    $period,
    $steps),
  (TensorExt_RotateAndReduceOp $tensor, $period, $steps),
  [(SplatValueIsOne $constantAttr)]
>;
That said, you only see this error if you replace (Arith_ConstantOp SplatElementsAttr:$constantAttr):$plaintexts with a plain $plaintexts. Otherwise, the nested DAG construct in place of the optional operand triggers a different error first: "use nested DAG construct to match op tensor_ext.rotate_and_reduce's variadic operand #1 unsupported now".
The only other use of an optional operand I could find in HEIR is in CKKSOps.td, but there are no TableGen rewrite patterns for the relinearize op. Maybe the optional operand means this rewrite pattern has to be written in C++?
The argument-count error looks like a problem with the replacement rather than the match: it seems it can't find a builder that takes 3 arguments. You could try adding a custom builder for the case where no plaintexts operand is provided. See this pending PR for an example: https://github.com/google/heir/pull/2129/files#diff-343fe6f3498bfe996336022a6e2ce5acf52a47717bf7e5100c8d3601bb094bb5R227
In the worst case, DRR may simply not be compatible with optional operands, in which case you should define the canonicalization patterns as C++ patterns.
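If it does come down to a C++ pattern, the logic it needs is small. Here is a minimal, MLIR-free sketch of the check-and-drop step, just to pin down the intended behavior. The RotateAndReduce struct and dropAllOnesPlaintexts function are hypothetical stand-ins and not HEIR or MLIR APIs; a real implementation would presumably be an OpRewritePattern whose matchAndRewrite performs the same test on the op's optional operand.

```cpp
#include <algorithm>
#include <optional>
#include <vector>

// Hypothetical stand-in for tensor_ext.rotate_and_reduce (not HEIR's op class).
struct RotateAndReduce {
  std::vector<double> tensor;
  // Models the optional $plaintexts operand: empty optional means "not provided".
  std::optional<std::vector<double>> plaintexts;
  int period;
  int steps;
};

// Models the DropAllOnesPlaintexts rewrite.
// Returns true and removes the plaintexts when they are present and every
// value equals 1; returns false and leaves the op untouched otherwise,
// which corresponds to the pattern failing to match.
bool dropAllOnesPlaintexts(RotateAndReduce &op) {
  if (!op.plaintexts.has_value()) return false;  // nothing to drop
  const std::vector<double> &values = *op.plaintexts;
  bool allOnes = std::all_of(values.begin(), values.end(),
                             [](double v) { return v == 1.0; });
  if (!allOnes) return false;
  op.plaintexts.reset();  // the op now has no plaintexts operand
  return true;
}
```

Returning false when the operand is already absent matters for a canonicalization: a pattern that reports success without changing anything would keep the rewrite driver looping.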