ImplementShiftNetworks: support rightward shifts as well
A layout change that reduces to a single rightward shift by 1 would require all power-of-two rotation keys, because it is evaluated as a leftward shift by -1, and -1 (modulo the ciphertext size) is 0b11111...1. Supporting rightward shifts in addition to leftward shifts (perhaps in tandem with #2256) would fix this.
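For concreteness, here is a small Python sketch of what a leftward-only power-of-two decomposition does with such a shift. `left_power_of_two_shifts` is a hypothetical helper, not HEIR's C++ code, and it assumes a power-of-two ciphertext size so that reducing modulo the size keeps exactly the low bits of the two's-complement value:

```python
def left_power_of_two_shifts(shift: int, ct_size: int) -> list[int]:
    """Leftward power-of-two rotations that compose to `shift` (mod ct_size).

    Hypothetical helper; assumes ct_size is a power of two, so `shift % ct_size`
    keeps exactly the low log2(ct_size) bits of the two's-complement value.
    """
    canonical = shift % ct_size  # non-negative in Python for positive ct_size
    num_bits = ct_size.bit_length() - 1
    return [1 << k for k in range(num_bits) if canonical & (1 << k)]


# A rightward shift by 1 is a leftward shift by -1, i.e. 15 == 0b1111 for size 16.
print(left_power_of_two_shifts(-1, 16))  # [1, 2, 4, 8]
```

So a single rightward rotation by 1 ends up touching every power-of-two key, whereas a rightward key of 1 alone would suffice.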
A NAF-based (non-adjacent form) representation could be useful here. Intuitively, it could also help reduce collisions, since the digit set extends from {0,1} to {-1,0,1} and fewer nonzero digits are needed to represent the same value.
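A rough Python sketch of the NAF idea, assuming a single cyclic ciphertext of power-of-two size, so that a rotation by a multiple of the size is a no-op and can be dropped. `naf` is the textbook conversion; `naf_shifts` is a hypothetical helper that maps its nonzero digits to signed rotations, with negative values meaning rightward:

```python
def naf(n: int) -> list[int]:
    """Non-adjacent form of a non-negative integer, least-significant digit first."""
    digits = []
    while n > 0:
        if n & 1:
            d = 2 - (n % 4)  # 1 if n ≡ 1 (mod 4), -1 if n ≡ 3 (mod 4)
            n -= d           # makes n divisible by 4, forcing the next digit to 0
        else:
            d = 0
        digits.append(d)
        n //= 2
    return digits


def naf_shifts(shift: int, ct_size: int) -> list[int]:
    """Signed power-of-two rotations from the NAF of the canonical leftward shift.

    Hypothetical helper; drops rotations that are multiples of ct_size (no-ops).
    """
    canonical = shift % ct_size
    return [d << k for k, d in enumerate(naf(canonical))
            if d != 0 and (d << k) % ct_size != 0]


print(naf_shifts(11, 16))  # [-1, -4]
print(naf_shifts(-1, 16))  # [-1]
```

For example, a leftward shift of 11 (0b1011) needs three leftward rotations {1, 2, 8}, while the NAF view gives two, since 11 ≡ -5 (mod 16). The sketch says nothing about whether this reduces conflicts in the network itself.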
While having NAF support is not a bad idea (citation: https://eprint.iacr.org/2021/1161), I think the simplest path forward is to give negative powers of two as shiftOrder input to VosVosErkinShiftNetworks. The randomized approach of #2256 could do this, but it may also be useful to guarantee a standard set of negative shifts is deterministically attempted.
(cf. https://github.com/google/heir/pull/2240, which hasn't been merged yet)
Related: https://github.com/google/heir/issues/744
This issue has 2 outstanding TODOs:
- tests/Dialect/TensorExt/Transforms/implement_shift_network.mlir:25: this test should only produce one rotation by -1
- tests/Transforms/layout_optimization/multiple_uses.mlir:4: without smartly picking power-of-two shifts, this test breaks
This comment was autogenerated by todo-backlinks
I tried to look into this after implementing random network construction, but it seems to be a lot more complicated than it sounds.
It's easy enough to use either all-left or all-right shifts, represent the latter as negative numbers, and then adapt ShiftStrategy::evaluate so that it works with right shifts as well. However, the resulting conflict graph is then not isomorphic to the one you'd get from left shifts, and I couldn't figure out how to adapt to that.
I'm pretty sure I read at least parts of the relevant paper somewhere, but I can't locate an open-access version anymore, so I'm not sure whether it discusses this detail.
I don't think the paper describes how to handle negative shifts at all. My thought was indeed to just try all-right shifts by negating all the values of an attempted all-left shift run. Could you say more about the problems that arise when you try this? The graph will be different, yes, but I don't think the graph is where the complicated parts are. Instead, there are parts of the current code that hard-code that we're doing leftward shifts (for example, here when a "normalized shift" is constructed and here where we do a bitwise-and to determine if a power-of-two shift is needed as part of a larger overall shift (and also here)).
With negations of powers of two, you will either need to add special cases at each of these places, or add a layer of indirection to smooth over the difference. For instance, with a ciphertext size of 16 and a leftward shift of 9 starting from source index 3, leftward power-of-two shifts would mean you need shifts {1, 8}. With rightward shifts, you'd need shifts {-1, -2, -4}; to determine that with bitwise operators, you'd convert the canonical leftward shift to a canonical rightward shift (16 - 9 = 7 = 0b111) and bitwise-and it with the rightward shifts (or the negations of the negative leftward shifts).
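A minimal Python sketch of that conversion, assuming a power-of-two ciphertext size and writing rightward rotations as negative numbers. `rightward_shifts_needed` is a hypothetical name, not a HEIR function:

```python
def rightward_shifts_needed(left_shift: int, ct_size: int) -> list[int]:
    """Express a leftward shift using only rightward power-of-two rotations.

    Hypothetical helper; rightward rotations are negative, ct_size is a power of two.
    """
    canonical_left = left_shift % ct_size
    # The same rotation measured rightward (0 stays 0).
    canonical_right = (ct_size - canonical_left) % ct_size
    num_bits = ct_size.bit_length() - 1
    # Bitwise-and the rightward amount against each power of two.
    return [-(1 << k) for k in range(num_bits) if canonical_right & (1 << k)]


print(rightward_shifts_needed(9, 16))   # [-1, -2, -4]
print(rightward_shifts_needed(-1, 16))  # [-1]
```

In this sketch only the shift amount matters; the source index does not enter into which power-of-two rotations are needed.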
A layer of indirection sounds like the right approach to me (and it could later be generalized to support more complex sets of shifts). I haven't thought through it in detail, but I imagine a class would encapsulate the chosen shifts and ordering, and the API would expose things like:
- Iterate over all the shifts in order
- Determine the subset of shifts needed to implement a "canonical leftward shift"
- Maybe something to accommodate applyVirtualRotation, which heavily depends on the left-shift convention, though now that I look at it again, it seems like it just applies rotation arithmetic agnostically of the values...
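One possible shape for that layer, sketched in Python rather than C++ for brevity. `ShiftSet` and `decompose` are hypothetical names, not existing HEIR APIs. The sketch assumes a power-of-two ciphertext size and a set of distinct power-of-two shifts that all point in the same direction, and it does not try to cover applyVirtualRotation:

```python
from collections.abc import Iterator
from dataclasses import dataclass


@dataclass(frozen=True)
class ShiftSet:
    """Hypothetical ordered set of power-of-two rotations (not a HEIR API).

    Positive entries rotate leftward, negative entries rotate rightward.
    """

    ct_size: int
    shifts: tuple[int, ...]

    def __post_init__(self) -> None:
        if not (all(s > 0 for s in self.shifts) or all(s < 0 for s in self.shifts)):
            raise ValueError("shifts must be all leftward or all rightward")
        magnitudes = [abs(s) for s in self.shifts]
        if len(set(magnitudes)) != len(magnitudes) or any(
            m & (m - 1) or m >= self.ct_size for m in magnitudes
        ):
            raise ValueError("shifts must be distinct powers of two below ct_size")

    def __iter__(self) -> Iterator[int]:
        """Iterate over all shifts in the chosen order."""
        return iter(self.shifts)

    def is_rightward(self) -> bool:
        return bool(self.shifts) and self.shifts[0] < 0

    def decompose(self, canonical_left_shift: int) -> list[int]:
        """Subset of shifts, in order, that implements a canonical leftward shift."""
        amount = canonical_left_shift % self.ct_size
        if self.is_rightward():
            # Re-express the same rotation as a rightward amount before masking.
            amount = (self.ct_size - amount) % self.ct_size
        chosen = [s for s in self.shifts if amount & abs(s)]
        if sum(abs(s) for s in chosen) != amount:
            raise ValueError("this shift set cannot express the requested rotation")
        return chosen


left = ShiftSet(16, (1, 2, 4, 8))
right = ShiftSet(16, (-8, -4, -2, -1))  # the chosen order is preserved
print(left.decompose(9))   # [1, 8]
print(right.decompose(9))  # [-4, -2, -1]
```

Callers only ever ask for "the shifts in order" and "the subset for this canonical leftward shift", so the rest of the pass would not need to know which direction was chosen.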
@j2kun
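Converting the canonical leftward shift to a canonical rightward shift and bitwise-and-ing it with the rightward shifts is more or less what I tried to do. I'll dig up my changes again and see where the issue was.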
@j2kun So here's what I had, changing the default shift order to all right shifts: https://github.com/Time0o/heir/commit/2ec51f6aee488ec1155fe3dba4da53621a93e2cd
When running the first shift network test, it looks to me like this is constructing sensible rotation groups: https://gist.github.com/Time0o/d11c438b0b8811409d3d90b67c04fafd
But the test fails with:
Value of: checkMapping(mapping, numCts, ctSize)
Actual: false (which has these unexpected elements: { 1, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0 },
and doesn't have these expected elements: { 3, 15, 13, 9, 2, 7, 11, 5, 1, 14, 12, 4, 10, 0, 6, 8 })
I wasn't able to fully understand what is happening in RotationGroupKernel.h, so I assume you're right and some adjustment is necessary.
@Time0o I tweaked your patch slightly to make the tests pass, though it still hard-codes right shifts. If we can convert the patch into something that handles both left and right shifts, it should be good. After the unit tests, the main test to check is that @test_no_conflicts2 in tests/Dialect/TensorExt/Transforms/implement_shift_network.mlir produces a single rotation (it does under the patch below, but I didn't update the test).
https://gist.github.com/j2kun/0fa63b29381fb5c98a15c8a04030c6bc
I think the core complexity here is that:
- defaultShiftOrder and normalizeShift both assume leftward rotation, but are otherwise unconnected. Keeping them in sync, so that both do rightward or both do leftward, helps.
- The required shift and the per-round rotation being compared may differ in sign. In that patch I hacked around this by taking the abs of both sides, but that only works when the signs agree on both sides; it wouldn't work if just one of the two had its sign flipped.
- The part that actually applies the rotation group has some tricky edge cases related to rotations that split the rotated ciphertext across two target ciphertexts. It calculates the "boundary" index (the last index before the rotated ciphertext spills over into the next ciphertext in this "virtual ciphertext" system), and that boundary index changes if you rotate right.
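To illustrate how the split point moves, here is a simplified Python model. It is hypothetical and not the logic in RotationGroupKernel.h: `num_cts` ciphertexts of `ct_size` slots form one cyclic virtual ciphertext, a positive shift rotates leftward, and a negative shift rotates rightward. The boundary returned here is the first slot of the rotated ciphertext whose element belongs to a different target ciphertext than slot 0 (one past the "last index" phrasing above):

```python
def split_rotation(ct_index: int, shift: int, ct_size: int, num_cts: int):
    """Map each slot of a rotated ciphertext to (target ciphertext, target slot).

    Hypothetical model: positive shift = leftward (slot i moves to i - shift),
    negative shift = rightward. Returns the mapping and the boundary slot.
    """
    assert 0 < abs(shift) < ct_size, "sketch assumes a nonzero shift smaller than ct_size"
    total = ct_size * num_cts
    mapping = []
    for slot in range(ct_size):
        # After rotating, this slot holds the element from source slot `src`.
        src = (slot + shift) % ct_size
        # That element's destination in the virtual ciphertext.
        virt = (ct_index * ct_size + src - shift) % total
        mapping.append(divmod(virt, ct_size))  # (target ciphertext, target slot)
    # Left by s: slots [ct_size - s, ct_size) belong to the other target.
    # Right by s: slots [0, s) belong to the other target.
    boundary = ct_size - shift if shift > 0 else -shift
    return mapping, boundary


left_map, left_boundary = split_rotation(1, 3, 8, 2)
right_map, right_boundary = split_rotation(0, -3, 8, 2)
print(left_boundary, right_boundary)  # 5 3
print(left_map[4], left_map[5])       # (1, 4) (0, 5)
print(right_map[2], right_map[3])     # (1, 2) (0, 3)
```

With ct_size = 8, a leftward rotation by 3 splits at slot 5 (8 - 3), while a rightward rotation by 3 splits at slot 3, so code that hard-codes ct_size - shift would pick the wrong boundary for right shifts.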
@j2kun Thank you. The third point seems quite difficult, but the first two sound very doable. If you haven't already started on this, I'll see what I can do.
Well, the change in the diff for RotationGroupKernel.h suffices to handle the third point, and it's probably general enough to be used as-is.
@j2kun That makes sense; I think I somewhat misunderstood your last comment. I have tried adapting your change to allow initializing ShiftStrategy objects with a "shift kind" parameter, which I think is more maintainable in this regard than allowing arbitrary shift orders and then somehow figuring out how to handle them. This could probably be abstracted further to allow more than all-left or all-right shifting; I'll try to create another patch tomorrow.
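A Python sketch of the "shift kind" idea, where a single parameter drives both the default shift order and the normalization so the two cannot drift apart. `ShiftKind`, `default_shift_order`, `normalize_shift`, and `shifts_for` are hypothetical stand-ins loosely modeled on defaultShiftOrder and normalizeShift; the sketch assumes a power-of-two ciphertext size and treats the input shift as leftward:

```python
from enum import Enum


class ShiftKind(Enum):
    """Hypothetical direction parameter; the value is the sign of every shift."""
    LEFT = 1
    RIGHT = -1


def default_shift_order(kind: ShiftKind, ct_size: int) -> list[int]:
    # Powers of two 1, 2, ..., ct_size / 2, all pointing in `kind`'s direction.
    return [kind.value * (1 << k) for k in range(ct_size.bit_length() - 1)]


def normalize_shift(left_shift: int, kind: ShiftKind, ct_size: int) -> int:
    # The rotation amount measured in `kind`'s direction, in [0, ct_size).
    return (kind.value * left_shift) % ct_size


def shifts_for(left_shift: int, kind: ShiftKind, ct_size: int) -> list[int]:
    # Both helpers take the same `kind`, so they always agree on direction.
    amount = normalize_shift(left_shift, kind, ct_size)
    return [s for s in default_shift_order(kind, ct_size) if amount & abs(s)]


print(shifts_for(9, ShiftKind.LEFT, 16))    # [1, 8]
print(shifts_for(9, ShiftKind.RIGHT, 16))   # [-1, -2, -4]
print(shifts_for(-1, ShiftKind.RIGHT, 16))  # [-1]
```

Compared with accepting an arbitrary shift order, the enum keeps the set of valid configurations small; extending it beyond all-left or all-right would mean adding kinds rather than handling arbitrary inputs.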