ImplementShiftNetworks: support rightward shifts as well

Open j2kun opened this issue 3 months ago • 12 comments

A layout change that reduces to a single rightward shift by 1 would require using all power-of-two rotation keys for a leftward shift, since it's evaluating -1 = 0b11111...1. Supporting rightward shifts as well as leftward shifts (perhaps in tandem with #2256) would fix this.

j2kun • Sep 24 '25 23:09
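To make the arithmetic in the issue description concrete, here is a minimal sketch (plain Python, not HEIR code; the function name is hypothetical). It assumes the ciphertext size is a power of two and that only leftward power-of-two rotations are available, so a rightward shift by 1 has to be reduced modulo the ciphertext size first.

```python
def left_only_rotation_count(shift: int, num_slots: int) -> int:
    """Number of power-of-two left rotations needed for a signed shift.

    Positive `shift` means leftward, negative means rightward. `num_slots`
    is assumed to be a power of two.
    """
    canonical_left = shift % num_slots  # Python's % yields a value in [0, num_slots)
    return bin(canonical_left).count("1")


num_slots = 16
print(bin(-1 % num_slots))                      # 0b1111
print(left_only_rotation_count(-1, num_slots))  # 4 rotations: 1, 2, 4, 8
print(left_only_rotation_count(1, num_slots))   # 1 rotation
```

With 16 slots, the single rightward shift costs four leftward rotations, which is the worst case for a left-only network.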

A NAF-based representation could be useful here. Intuitively, it can help reduce collisions as well, since the digit set extends from {0,1} to {-1,0,1} and fewer nonzero digits are used for the same value.

eymay • Sep 25 '25 12:09
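For readers unfamiliar with NAF (non-adjacent form), the following is a rough sketch of the standard digit-by-digit conversion and of how its signed digits could map to rotations. It is not taken from HEIR; `naf` and `naf_rotations` are hypothetical names, and dropping rotations that are multiples of the ciphertext size is an assumption about how a full-cycle rotation would be treated.

```python
def naf(value: int) -> list[int]:
    """Non-adjacent form of a non-negative integer, least significant digit first."""
    digits = []
    while value > 0:
        if value & 1:
            digit = 2 - (value % 4)  # +1 if value = 1 (mod 4), -1 if value = 3 (mod 4)
            value -= digit
        else:
            digit = 0
        digits.append(digit)
        value >>= 1
    return digits


def naf_rotations(left_shift: int, num_slots: int) -> list[int]:
    """Signed power-of-two rotations (positive = left) implied by the NAF digits."""
    rotations = [d << i for i, d in enumerate(naf(left_shift % num_slots)) if d != 0]
    # A rotation by a multiple of num_slots is the identity, so it is dropped.
    return [r for r in rotations if r % num_slots != 0]


print(naf(15))                # [-1, 0, 0, 0, 1], i.e. 16 - 1
print(naf_rotations(15, 16))  # [-1]
print(naf_rotations(7, 16))   # [-1, 8]
```

In the 16-slot case, the leftward shift by 15 collapses to a single rightward shift by 1, which is the case this issue is about.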

While having NAF support is not a bad idea (citation: https://eprint.iacr.org/2021/1161), I think the simplest path forward is to give negative powers of two as shiftOrder input to VosVosErkinShiftNetworks. The randomized approach of #2256 could do this, but it may also be useful to guarantee a standard set of negative shifts is deterministically attempted.

(cf. https://github.com/google/heir/pull/2240 which hasn't yet been merged)

j2kun • Sep 25 '25 16:09

Related: https://github.com/google/heir/issues/744

j2kun • Sep 25 '25 20:09

This issue has 2 outstanding TODOs:

This comment was autogenerated by todo-backlinks

github-actions[bot] • Sep 26 '25 21:09

I tried to look into this after implementing random network construction but it seems to be a lot more complicated than it sounds.

It's easy enough to do either all-left or all-right shifts, represent the latter by negative numbers, and then adapt ShiftStrategy::evaluate so that it works with right shifts as well. However, the resulting conflict graph is then not isomorphic to the one you'd get via left-shifting, and I couldn't figure out how to adapt to that.

I'm pretty sure I read at least parts of the relevant paper somewhere, but I can't locate an open access version anymore, so I'm not sure if it talks about this detail.

Time0o • Oct 04 '25 14:10

I don't think the paper describes how to handle negative shifts at all. My thought was indeed to just try all-right shifts by negating all the values of an attempted all-left shift run. Could you say more about the problems that arise when you try this? The graph will be different, yes, but I don't think the graph is where the complicated parts are. Instead, there are parts of the current code that hard-code that we're doing leftward shifts (for example, here when a "normalized shift" is constructed and here where we do a bitwise-and to determine if a power-of-two shift is needed as part of a larger overall shift (and also here)).

With negations of powers of two, you will either need to add special cases at each of these places, or add a layer of indirection to smooth over the difference. For instance, if you have a ciphertext size of 16 and a leftward shift of 9 starting from source index 3, if you have left-powers-of-two shifts, this would mean you need shifts {1, 8}. If you have rightward shifts, you'd need shifts of {-1, -2, -4}, and to determine that with bitwise operators you'd convert the canonical leftward shift to a canonical rightward shift, and bitwise-and it with the rightward shifts (or the negations of the negative leftward shifts).
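The numbers in this example can be checked with a few lines of Python. This is only a worked illustration of the bitwise reasoning, not code from the patch; the variable names are made up for the example.

```python
num_slots = 16
left_shift = 9                     # the canonical leftward shift from the example
powers = [1, 2, 4, 8]              # power-of-two shift magnitudes for 16 slots

# Left-powers-of-two: test the bits of the leftward amount directly.
left = [p for p in powers if left_shift & p]

# Right-powers-of-two: convert to the canonical rightward amount first,
# then test its bits and record the shifts as negative numbers.
right_amount = (num_slots - left_shift) % num_slots
right = [-p for p in powers if right_amount & p]

print(left)   # [1, 8]
print(right)  # [-1, -2, -4]
```

Both subsets sum to 9 modulo 16, so they implement the same overall rotation.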

A layer of indirection sounds like the right approach to me (and it could later be generalized to support more complex sets of shifts). I haven't thought through it in detail, but I imagine a class would encapsulate the chosen shifts and ordering, and the API would expose things like:

  • Iterate over all the shifts in order
  • Determine the subset of shifts needed to implement a "canonical leftward shift"
  • Maybe something to accommodate applyVirtualRotation which heavily depends on the left-shift convention, though now that I look at it again, it seems like it just applies rotation arithmetic agnostically of the values...

j2kun • Oct 04 '25 20:10
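One possible shape for the layer of indirection described above, written as a Python sketch rather than the C++ it would actually be in HEIR. Everything here is hypothetical (`ShiftScheme`, `all_left`, `all_right`, `required`); it only covers the first two API bullets and assumes a power-of-two ciphertext size with all shifts pointing in one direction.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass(frozen=True)
class ShiftScheme:
    """Hypothetical wrapper for an ordered, single-direction set of power-of-two shifts."""

    num_slots: int           # ciphertext size, assumed to be a power of two
    shifts: tuple[int, ...]  # signed: positive = leftward, negative = rightward

    @classmethod
    def all_left(cls, num_slots: int) -> "ShiftScheme":
        return cls(num_slots, tuple(1 << b for b in range(num_slots.bit_length() - 1)))

    @classmethod
    def all_right(cls, num_slots: int) -> "ShiftScheme":
        return cls(num_slots, tuple(-(1 << b) for b in range(num_slots.bit_length() - 1)))

    def __iter__(self) -> Iterator[int]:
        """Iterate over all shifts in round order."""
        return iter(self.shifts)

    def required(self, canonical_left_shift: int) -> list[int]:
        """Subset of shifts that composes to the given canonical leftward shift."""
        target = canonical_left_shift % self.num_slots
        chosen = []
        for shift in self.shifts:
            # Express the target in this shift's own direction before testing the bit.
            amount = target if shift > 0 else (self.num_slots - target) % self.num_slots
            if amount & abs(shift):
                chosen.append(shift)
        return chosen


print(ShiftScheme.all_left(16).required(9))    # [1, 8]
print(ShiftScheme.all_right(16).required(9))   # [-1, -2, -4]
print(ShiftScheme.all_right(16).required(15))  # [-1]
```

Callers would only ever pass canonical leftward shifts, so the direction-specific conversion lives in one place instead of at each call site that currently assumes leftward rotation.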

@j2kun

> With negations of powers of two, you will either need to add special cases at each of these places, or add a layer of indirection to smooth over the difference. For instance, if you have a ciphertext size of 16 and a leftward shift of 9 starting from source index 3, if you have left-powers-of-two shifts, this would mean you need shifts {1, 8}. If you have rightward shifts, you'd need shifts of {-1, -2, -4}, and to determine that with bitwise operators you'd convert the canonical leftward shift to a canonical rightward shift, and bitwise-and it with the rightward shifts (or the negations of the negative leftward shifts).

This is more or less what I tried to do. I'll dig up my changes again and see where the issue was.

Time0o • Oct 05 '25 08:10

@j2kun So here's what I had, changing the default shift order to all right shifts: https://github.com/Time0o/heir/commit/2ec51f6aee488ec1155fe3dba4da53621a93e2cd

When running the first shift network test it looks to me like this is constructing sensible rotation groups: https://gist.github.com/Time0o/d11c438b0b8811409d3d90b67c04fafd

But the test fails with:

Value of: checkMapping(mapping, numCts, ctSize)                                     
  Actual: false (which has these unexpected elements: { 1, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0 },
and doesn't have these expected elements: { 3, 15, 13, 9, 2, 7, 11, 5, 1, 14, 12, 4, 10, 0, 6, 8 })

I was not able to fully understand what is happening in RotationGroupKernel.h so I assume you're right and there is some adjustment necessary.

Time0o • Oct 05 '25 17:10

@Time0o I tweaked your patch slightly to make the tests pass, though it still hard-codes right shifts. I think that if we can convert the patch into something that handles both left and right shifts, it should be good. After the unit tests, the main test to check would be that @test_no_conflicts2 in tests/Dialect/TensorExt/Transforms/implement_shift_network.mlir produces a single rotation (it does under the patch below, but I didn't update the test).

https://gist.github.com/j2kun/0fa63b29381fb5c98a15c8a04030c6bc

I think the core complexity here is that:

  1. defaultShiftOrder and normalizeShift both assume leftward rotation, but are otherwise unconnected. Making them both do rightward or both do leftward in sync helps.
  2. The comparison of the required shift and the per-round rotation may differ on sign, and while I hacked it in that patch to take the abs of both sides, that wouldn't work if the sign of one of the two differed (it only works if the sign agrees on both sides)
  3. The part that actually applies the rotation group has some tricky edge cases related to rotations that split the rotated ciphertext across two target ciphertexts. There it calculates the "boundary" index (the last index before the rotated ciphertext spills over to the next ciphertext in this "virtual ciphertext" system) and that boundary index changes if you rotate right.

j2kun • Oct 13 '25 04:10
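Point 2 in the list above can be illustrated in isolation. The sketch below is not the comparison used in the patch; the helper names are hypothetical. It only shows why comparing absolute values is sound when both sides share a sign convention and breaks when they do not, and how reducing both sides to a canonical leftward amount sidesteps the problem.

```python
def canonical_left(shift: int, num_slots: int) -> int:
    """Map a signed shift to its equivalent leftward amount in [0, num_slots)."""
    return shift % num_slots


def same_rotation_abs(a: int, b: int) -> bool:
    """The abs-based comparison: only sound when a and b share a sign convention."""
    return abs(a) == abs(b)


def same_rotation(a: int, b: int, num_slots: int) -> bool:
    """Compare two signed shifts by the rotation they actually perform."""
    return canonical_left(a, num_slots) == canonical_left(b, num_slots)


num_slots = 16
# Right by 1 and left by 15 are the same rotation, but abs() disagrees.
print(same_rotation_abs(-1, 15), same_rotation(-1, 15, num_slots))  # False True
# Right by 1 and left by 1 are different rotations, but abs() says they match.
print(same_rotation_abs(-1, 1), same_rotation(-1, 1, num_slots))    # True False
```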

@j2kun Thank you, 3 seems quite difficult but 1&2 sound very doable. If you haven't already started on this I'll see what I can do.

Time0o • Oct 14 '25 10:10

Well the change in the diff for RotationGroupKernel.h suffices to handle 3, and it's probably general enough that it can be used as-is.

j2kun • Oct 14 '25 14:10

@j2kun That makes sense, I think I somewhat misunderstood your last comment. I have tried adapting your change to allow initializing ShiftStrategy objects with a "shift kind" parameter, which I think is more maintainable than allowing arbitrary shift orders and then somehow figuring out how to handle them. It could probably be abstracted further to allow more than all-left or all-right shifting. I'll try to create another patch tomorrow.

Time0o • Oct 14 '25 18:10