[frontend] support type hints in python code
I had some code for the AES transciphering demo. I know it's on my plate to write a redux of the full thing, but I wanted to post this issue so I can work on it.
There's an efficient sbox lookup implementation by Boyar-Peralta that uses bitwise operators. Here's an implementation: https://www.bearssl.org/gitweb/?p=BearSSL;a=blob;f=src/symcipher/aes_ct.c;h=66776d9e206c92cbbf3fa799c651a3d3652bd75a;hb=HEAD#l29
I could almost copy and paste the body (the bit assignments were more complicated), but the annoying part was that all the arithmetic ended up in i64 in the frontend. This would be pretty easily solvable if I could give a type hint to the x0, ..., x7's as i1s. Something like
```python
def sub_bytes(x: Secret[I8]):
    x0: I1 = (x >> 7) & 1
    x1: I1 = (x >> 6) & 1
    x2: I1 = (x >> 5) & 1
    x3: I1 = (x >> 4) & 1
    ...
```
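For reference, the decomposition being written out above (and its inverse) can be sketched in plain Python; `to_bits`/`from_bits` are illustrative helper names, not frontend API:

```python
def to_bits(x):
    # Decompose an 8-bit value into bits, MSB first: x0 = bit 7, ..., x7 = bit 0
    return [(x >> (7 - i)) & 1 for i in range(8)]

def from_bits(bits):
    # Recombine MSB-first bits back into an 8-bit value
    v = 0
    for b in bits:
        v = (v << 1) | (b & 1)
    return v
```

Each `xi` then corresponds to `to_bits(x)[i]`, and the sbox output is reassembled with `from_bits`.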
Hm, that's odd - I thought I had overridden the constant behavior in numba so that it sets integer constants to the smallest size in [i8, i16, i32, i64] that they fit into, so I'd have expected the constant to be an i8, but it's `%0 = arith.constant 7 : i32` ❓
Found the bug, I'm an absolute idiot, PR incoming... EDIT: See #1794
Concerning the actual type-hint issue: what do you want the `I1` type hint to do here? It seems like you want `I8` (which my fix would give you), but if you actually want this to be cast down to `Bool`, I don't think a type hint would be sufficient for that; we'd need something like `x0 = cast((x >> 7) & 1, I1)`
From the PR:
@asraa:
I'm extracting the 8th bit of an i8 here, and want to truncate that to an i1 MLIR type so that the following bitwise arithmetic happens under i1. Instead, the python frontend, which thinks x is i8, will keep all arithmetic in i8 (and then sometimes promote to i64, which is probably another bug)
Oh, I see - I'm not sure type hints are the right way to go about this (a custom numba intrinsic cast(<value>,<type>) might be easier), but I'll see what I can do!
Ah! OK - a cast -> trunci might be the right thing, you're right. It was also partly a thought towards letting a user specify the semantics of a program (does an i8 * i8 result end up an i8 or an i16?)
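The semantics question above can be made concrete in plain Python; `mul_wrap_i8`/`mul_widen_i8` are hypothetical names just to illustrate the two choices:

```python
def mul_wrap_i8(a, b):
    # i8 * i8 -> i8: the product wraps modulo 2^8
    return (a * b) & 0xFF

def mul_widen_i8(a, b):
    # i8 * i8 -> i16: the product is widened, so no wrap for i8 inputs
    return (a * b) & 0xFFFF

# 16 * 16 = 256: wraps to 0 under i8 semantics, stays 256 under i16
```

Which of the two a frontend picks changes whether a program that "doesn't overflow" in Python still doesn't overflow after lowering.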
Ok, so on the frontend/numba side it seems to actually be pretty easy to do "casts" within their existing numba.extending system.
This would allow you to write
```python
def sub_bytes(x: Secret[I8]):
    x0 = I1((x >> 7) & 1)
    x1 = I1((x >> 6) & 1)
    x2 = I1((x >> 5) & 1)
    x3 = I1((x >> 4) & 1)
    ...
```
This means the `I1`, `I8`, etc. classes are no longer pure annotations but can actually take part in python code (for non-FHE runs), though they're just pretty simple forwarding functions (see https://github.com/AlexanderViand/heir/commit/8634e1f1367b2081a9ff525f12f82bd9f970585d )
On the numba side, `I1(x)`/`I8(x)`/etc. is simply a function that always outputs a bool/int8/etc., so type inference works as expected.
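A minimal sketch of what such forwarding functions might look like for non-FHE runs (this is only an illustration, assuming unsigned truncation semantics; the actual implementation is in the commit linked above):

```python
class _IntCast:
    """Callable stand-in for an integer type: truncates its input to `bits` bits."""

    def __init__(self, bits):
        self.bits = bits

    def __call__(self, value):
        # Keep only the low `bits` bits, like an unsigned trunci
        return int(value) & ((1 << self.bits) - 1)

# "Types" that double as cast functions in plain Python
I1 = _IntCast(1)
I8 = _IntCast(8)

# In a non-FHE run, I1(...) just evaluates to a plain Python int:
x = 0b10110010
x0 = I1((x >> 7) & 1)  # 1
```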
For actually producing MLIR, there's a bit of work to do, but the casts appear as function calls, so it shouldn't be too hard to emit the right truncation/extension/integer-to-float/float-to-integer op.
Currently, there's nothing to stop one from writing `I1("some_string")`, though I'll probably add some sanity checks both to the Python types (so that a non-FHE run of this complains) and to the MLIR emitter (so that we don't expose exotic MLIR verifier failures in the trunc/ext/etc. ops to the user).
Oh! That's amazing - that works great. (As far as sanity checks go, I think some of that same logic for what to emit could be reused - e.g., if we construct a dictionary of [to -> from] : op, checking whether a string -> i1 conversion op exists would fail.)
This is great, thanks!!
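The dictionary idea above could look something like this; the arith op names are real MLIR ops, but the table contents and the `cast_op` helper are only an illustration:

```python
# Hypothetical lookup: (from_type, to_type) -> arith dialect conversion op.
# Anything missing from the table (e.g. string -> i1) is rejected,
# which doubles as the sanity check discussed above.
CAST_OPS = {
    ("i8", "i1"): "arith.trunci",
    ("i1", "i8"): "arith.extui",
    ("i8", "f32"): "arith.sitofp",
    ("f32", "i8"): "arith.fptosi",
}

def cast_op(from_ty, to_ty):
    op = CAST_OPS.get((from_ty, to_ty))
    if op is None:
        raise TypeError(f"no conversion op from {from_ty} to {to_ty}")
    return op
```

The same table then drives both emission (look up the op) and validation (reject pairs with no entry).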
I had some time to work on this today and mostly finished the frontend side of this but realized this brings up some questions about how we handle these various types in the arithmetic world:
- The noise management doesn't know about `arith.trunci`/`truncf`/`sitofp`/`uitofp`/`fptoui`/`fptosi` yet, and I'm not sure if simply adding them to the "exception lists" (e.g., here) of ignored operations would be correct?
- How would we actually lower a truncation in BGV/BFV or CKKS? So far, we've only handled extensions, which we treated as no-ops, assuming the plaintext space would already be large enough anyway. I guess we'd have to do either interpolation or approximation, but in the latter case I'm not sure how truncation would fit with our intended guarantee of "if your original program does not overflow, it also won't overflow in FHE".
- For int-to-float and float-to-int, this also brings up the more complicated issue of data semantics for float vs fixed point.
To get AES working in the frontend, we can just add it and throw an error message in heir-opt when encountering these ops in a non-CGGI pipeline, but I wanted to document my thoughts/questions somewhere.
OK, to speak to what we do want:
In CGGI land, we want these casts to get lowered to CGGI library casts https://github.com/google/heir/blob/55d46963b825837e4505c832ceb3f8fd1b7af491/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp#L185
That lowering happens on the non-Yosys side of things. On the Yosys side, `arith.extui`/`arith.trunci` are also implemented in Verilog and will yield better circuits too.
With the cast system introduced in #1825, I think it's safe to close this one for now, at least until we find a need for actual type hints.
@AlexanderViand I don't have the Python I scrapped from the bitwise C implementation above, but that was the meat of the AES implementation.
I think the AES funcs (in the common/aes folder) also require tensor slicing, which would be the other thing I'm not sure works in the python frontend.