Write a linalg.reduce kernel using the relation-based layout
With the new layout system, a linalg.reduce kernel is a bit more complicated, mainly due to the extra decisions possible with a multi-ciphertext packing.
In the old setup, we only allowed data to be packed into a single ciphertext, so the linalg.reduce kernel involved aligning the axis to be summed along a stride, then applying a partial rotate-and-reduce to sum its entries.
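To make the rotate-and-reduce idea concrete, here is a plaintext sketch (not HEIR code) that models a packed ciphertext as a Python list of slots; `rotate` stands in for a homomorphic slot rotation, and the helper names are illustrative only. The full (non-partial) variant sums every slot in log2(n) rotations:

```python
def rotate(ct, k):
    """Model of a ciphertext rotation by k slots (plaintext stand-in)."""
    k %= len(ct)
    return ct[k:] + ct[:k]

def rotate_and_reduce(ct):
    """Sum all slots in log2(n) rotate-and-add steps; afterwards every
    slot holds the total. Assumes a power-of-two slot count."""
    n = len(ct)
    assert n & (n - 1) == 0, "sketch assumes power-of-two slot count"
    shift = n // 2
    while shift >= 1:
        ct = [a + b for a, b in zip(ct, rotate(ct, shift))]
        shift //= 2
    return ct

print(rotate_and_reduce([1, 2, 3, 4]))  # [10, 10, 10, 10]
```

The "partial" variant mentioned above would stop rotating at the stride of the reduced axis, so that only the entries along that axis are summed into each slot group.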
In the new setup, given an arbitrary input packing, we could utilize this same approach, or align slots across multiple ciphertexts, or a combination of both. I'm not sure what the right approach is (haven't thought about it too hard), but in general it may require repacking the ciphertext to make slots align.
What we should do is decide on some kernel options, add them to layout-optimization so it can handle the implied layout conversions required to make the kernel feasible, then implement the kernel chosen by the optimizer in ConvertToCiphertextSemantics for the assumed layout.
This issue has 1 outstanding TODO:
- lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp:495: Implement a proper kernel
This comment was autogenerated by todo-backlinks
Good first step: force a layout conversion of the operands so that the slots are aligned per ciphertext, and so that linalg.reduce (for arith.addi/f) can be implemented just by summing ciphertexts elementwise.
We know this will produce a slow program if, say, the inputs to linalg.reduce come from something with a nontrivial packing like a matvec. However, this will get initial support into the compiler for linalg.reduce, and followup work would involve some specific knowledge of the layout relation to determine a better kernel.
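As a plaintext sketch of this first step (again modeling each ciphertext as a list of slots, with hypothetical helper names): once the layout conversion has aligned slots across ciphertexts, the reduction for arith.addi/f is just an elementwise ciphertext sum, with no rotations at all:

```python
def reduce_aligned(cts):
    """Sum ciphertexts elementwise, assuming the layout conversion has
    already aligned corresponding entries into the same slot index."""
    out = cts[0]
    for ct in cts[1:]:
        out = [a + b for a, b in zip(out, ct)]
    return out

# Two "ciphertexts" holding rows [1,2,3] and [4,5,6]; column sums:
print(reduce_aligned([[1, 2, 3], [4, 5, 6]]))  # [5, 7, 9]
```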
@j2kun I’m currently working on this. I’ve assumed the operand layout is row-major, with each row packed into its own ciphertext. For a reduction along dimension 0, I perform elementwise addition across ciphertexts; for a reduction along dimension 1, I apply rotate-and-reduce within each row and then stack the per-row results. Could you confirm whether this approach aligns with the intended direction for this issue?
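A plaintext sketch of that proposal, assuming the row-major, one-ciphertext-per-row packing described above (function names are illustrative, not HEIR APIs; in a real kernel the dim-1 "stacking" step would need masks and rotations to pack the per-row sums into one ciphertext):

```python
def rotate(ct, k):
    """Model of a ciphertext rotation by k slots (plaintext stand-in)."""
    k %= len(ct)
    return ct[k:] + ct[:k]

def reduce_rowmajor(cts, dim):
    """Reduce a 2D tensor packed row-major, one ciphertext per row."""
    if dim == 0:
        # Slots are already aligned across rows: elementwise ciphertext sum.
        out = cts[0]
        for ct in cts[1:]:
            out = [a + b for a, b in zip(out, ct)]
        return out
    # dim == 1: rotate-and-reduce within each row (power-of-two slot count),
    # then collect the per-row totals; a real kernel would mask and rotate
    # these into a single packed ciphertext.
    results = []
    for ct in cts:
        shift = len(ct) // 2
        while shift >= 1:
            ct = [a + b for a, b in zip(ct, rotate(ct, shift))]
            shift //= 2
        results.append(ct[0])  # every slot now holds the row sum
    return results

m = [[1, 2, 3, 4], [5, 6, 7, 8]]
print(reduce_rowmajor(m, 0))  # [6, 8, 10, 12]
print(reduce_rowmajor(m, 1))  # [10, 26]
```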
I think this aligns with the idea in https://github.com/google/heir/issues/2254#issuecomment-3452260868, though if the data tensor is 3-dimensional this will not suffice, even with a row-major layout (this is part of why I think the problem is not so trivial, and why I filed an issue instead of solving it myself!).
I wrote up that comment earlier today because we met with someone in the Monday morning office hours who also wanted to work on this problem. I hope she is reading this thread to know someone else is also working on it (maybe you two could compare notes as well).