[frontend] support type hints in python code

Open asraa opened this issue 7 months ago • 7 comments

I had some code for the AES transciphering demo. I know it's on my plate to write a redux of the full thing, but I wanted to post this issue so I can work on it.

There's an efficient S-box lookup implementation by Boyar-Peralta that uses bitwise operators. Here's an implementation: https://www.bearssl.org/gitweb/?p=BearSSL;a=blob;f=src/symcipher/aes_ct.c;h=66776d9e206c92cbbf3fa799c651a3d3652bd75a;hb=HEAD#l29

I could almost copy and paste the body (the bit assignments were more complicated), but the annoying part was that all the arithmetic ended up in i64 in the frontend. This could be solved pretty easily if I could give a type hint marking x0, ..., x7 as i1s. Something like:

def sub_bytes(x: Secret[I8]):
    x0: I1 = (x >> 7) & 1
    x1: I1 = (x >> 6) & 1
    x2: I1 = (x >> 5) & 1
    x3: I1 = (x >> 4) & 1
...
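For reference, the bit extraction above can be sketched in plain Python (no HEIR or numba involved). The helper names `byte_to_bits` and `bits_to_byte` are hypothetical and only illustrate the x0 = most-significant-bit convention used in the snippet:

```python
def byte_to_bits(x: int) -> list[int]:
    """Return [x0, ..., x7], where x0 = (x >> 7) & 1 is the most significant bit."""
    if not 0 <= x <= 0xFF:
        raise ValueError("expected a byte value in [0, 255]")
    return [(x >> (7 - i)) & 1 for i in range(8)]


def bits_to_byte(bits: list[int]) -> int:
    """Inverse of byte_to_bits: reassemble a byte from MSB-first bits."""
    out = 0
    for b in bits:
        out = (out << 1) | (b & 1)
    return out
```

Each entry of the returned list is 0 or 1, which is exactly the range an i1 could hold; the open question in this issue is how to tell the frontend that.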

asraa avatar May 05 '25 16:05 asraa

Hm, that's odd - I thought I had overridden the constant behavior in numba so that it'd set integer constants to the smallest size in [i8,i16,i32,i64] that they'd fit into, so I'd have expected to see the constant be an i8, but it's an i32: %0 = arith.constant 7 : i32
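A minimal sketch of the rule I had in mind, assuming signed two's-complement ranges. This is not the actual numba override or HEIR code, and `smallest_int_width` is a made-up name:

```python
def smallest_int_width(value: int, widths=(8, 16, 32, 64)) -> int:
    """Return the smallest width in `widths` whose signed range contains `value`."""
    for w in widths:
        lo = -(1 << (w - 1))
        hi = (1 << (w - 1)) - 1
        if lo <= value <= hi:
            return w
    raise OverflowError(f"{value} does not fit into any of {widths}")
```

Under this rule the constant 7 in `x >> 7` would be typed as an 8-bit integer, which is what I expected to see in the emitted MLIR.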

AlexanderViand avatar May 06 '25 17:05 AlexanderViand

Found the bug, I'm an absolute idiot, PR incoming... EDIT: See #1794

Concerning the actual type-hint/issue: what do you want the I1 type hint to do here? It seems like you want I8 (which my fix would give you) but if you actually want this to be cast down to Bool, I don't think a type hint would be sufficient for that, we'd need something like x0 = cast( (x >> 7) & 1, I1)

AlexanderViand avatar May 06 '25 17:05 AlexanderViand

From the PR:

@asraa:

I'm extracting the 8th bit of an i8 here, and want to truncate that to an i1 MLIR type so that the following bitwise arithmetic happens in i1. Instead, the Python frontend, which thinks x is i8, keeps all arithmetic in i8 (and sometimes promotes it to i64, which is probably another bug).

Oh, I see - I'm not sure type hints are the right way to go about this (a custom numba intrinsic cast(<value>,<type>) might be easier), but I'll see what I can do!

AlexanderViand avatar May 06 '25 19:05 AlexanderViand

Oh, I see - I'm not sure type hints are the right way to go about this (a custom numba intrinsic cast(<value>,<type>) might be easier), but I'll see what I can do!

Ah! OK - a cast -> trunci might be the right thing, you're right. I was also partly thinking about letting a user specify the semantics of a program (does an i8 * i8 result end up as an i8 or an i16?)
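To make the i8 * i8 question concrete, here is a plain-Python sketch of the two candidate semantics, assuming signed two's-complement values. The helper names are hypothetical and nothing here is HEIR API:

```python
def wrap_signed(value: int, bits: int) -> int:
    """Reduce `value` modulo 2**bits and reinterpret it as a signed integer."""
    value &= (1 << bits) - 1
    if value >= 1 << (bits - 1):
        value -= 1 << bits
    return value


def mul_i8_wrapping(a: int, b: int) -> int:
    """i8 * i8 -> i8: the product wraps around on overflow."""
    return wrap_signed(a * b, 8)


def mul_i8_widening(a: int, b: int) -> int:
    """i8 * i8 -> i16: the product of two i8 values always fits in 16 bits."""
    return wrap_signed(a * b, 16)
```

For example, 100 * 3 gives 44 under the wrapping rule and 300 under the widening rule, so the choice is visible to the user.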

asraa avatar May 06 '25 19:05 asraa

Ok, so on the frontend/numba side it seems to actually be pretty easy to do "casts" within their existing numba.extending system.

This would allow you to write

def sub_bytes(x: Secret[I8]):
    x0 = I1((x >> 7) & 1)
    x1 = I1((x >> 6) & 1)
    x2 = I1((x >> 5) & 1)
    x3 = I1((x >> 4) & 1)
    ...

This means the I1, I8, etc classes are no longer pure annotations, but can actually take part in python code (for non-FHE runs) though they're just pretty simple forwarding functions (See https://github.com/AlexanderViand/heir/commit/8634e1f1367b2081a9ff525f12f82bd9f970585d )

On the numba side, I1(x)/I8(x)/etc is simply a function that always outputs a bool/int8/etc, so type inference works as expected.

For actually producing MLIR there's still a bit of work to do, but the casts appear as function calls, so it shouldn't be too hard to emit the right truncation/extension/integer-to-float/float-to-int op.

Currently, there's nothing to stop one from writing I1("some_string") though I'll probably add some sanity checks to both the Python types (so that a non-FHE run of this complains) and to the MLIR emitter (so that we don't expose exotic MLIR verifier failures in the trunc/ext/etc ops to the user).
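As a rough idea of what such forwarding functions with sanity checks could look like for non-FHE runs, here is a sketch. It is not the code from the linked commit: the factory `_make_int_cast` is hypothetical, and having I1 keep only the lowest bit (truncation semantics) is an assumption.

```python
import numbers


def _check_numeric(value, name):
    # Reject strings and other non-numeric inputs with a readable error.
    if not isinstance(value, numbers.Real):
        raise TypeError(f"{name}() expects a number, got {type(value).__name__}")


def I1(value) -> bool:
    """Cast to a 1-bit value by keeping the lowest bit (assumed semantics)."""
    _check_numeric(value, "I1")
    return bool(int(value) & 1)


def _make_int_cast(bits: int, name: str):
    def cast(value) -> int:
        _check_numeric(value, name)
        v = int(value) & ((1 << bits) - 1)  # keep the low `bits` bits
        return v - (1 << bits) if v >= 1 << (bits - 1) else v  # signed view
    cast.__name__ = name
    return cast


I8 = _make_int_cast(8, "I8")
I16 = _make_int_cast(16, "I16")
```

With this, `I1((x >> 7) & 1)` returns a Python bool in a plain run, while `I1("some_string")` raises a TypeError instead of failing somewhere deeper.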

AlexanderViand avatar May 06 '25 23:05 AlexanderViand

Oh! That's amazing - that works great. As far as the sanity checks go, I think some of the same logic for deciding what to emit could be reused: e.g. if we construct a dictionary of [to->from] : op, then looking up a string -> i1 conversion op would fail.
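A sketch of that dictionary idea, using type names as plain strings. The arith op names are real MLIR ops, but the table layout, the helper names, and the choice of the signed variants (extsi, sitofp, fptosi) are assumptions for illustration only:

```python
INT_WIDTHS = {"i1": 1, "i8": 8, "i16": 16, "i32": 32, "i64": 64}
FLOAT_WIDTHS = {"f32": 32, "f64": 64}


def build_cast_table() -> dict:
    """Map (source type, destination type) to the arith op that converts between them."""
    table = {}
    for src, sw in INT_WIDTHS.items():
        for dst, dw in INT_WIDTHS.items():
            if sw > dw:
                table[(src, dst)] = "arith.trunci"
            elif sw < dw:
                table[(src, dst)] = "arith.extsi"  # signed extension is an assumption
        for dst in FLOAT_WIDTHS:
            table[(src, dst)] = "arith.sitofp"
    for src, sw in FLOAT_WIDTHS.items():
        for dst in INT_WIDTHS:
            table[(src, dst)] = "arith.fptosi"
        for dst, dw in FLOAT_WIDTHS.items():
            if sw > dw:
                table[(src, dst)] = "arith.truncf"
            elif sw < dw:
                table[(src, dst)] = "arith.extf"
    return table


CAST_OPS = build_cast_table()


def lookup_cast_op(src: str, dst: str):
    """Return the op name for a cast, None for a same-type no-op, or raise TypeError."""
    if src == dst and (src in INT_WIDTHS or src in FLOAT_WIDTHS):
        return None
    try:
        return CAST_OPS[(src, dst)]
    except KeyError:
        raise TypeError(f"no conversion from {src} to {dst}") from None
```

The same lookup then serves both purposes: the emitter uses the returned op name, and a missing entry such as ("str", "i1") turns into a user-facing error.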

This is great, thanks!!

asraa avatar May 08 '25 15:05 asraa

I had some time to work on this today and mostly finished the frontend side of this, but realized it brings up some questions about how we handle these various types in the arithmetic world:

  • The noise management doesn't know about arith.trunci/truncf/sitofp/uitofp/fptoui/fptosi yet, and I'm not sure if simply adding them to the "exception lists" (e.g., here) of ignored operations would be correct?
  • How would we actually lower a truncation in BGV/BFV or CKKS? So far, we've only handled extensions, which we treated as no-ops, assuming the plaintext space would anyway already be large enough. I guess we'd have to do either interpolation or approximation, but in the latter case I'm not sure how truncation would fit with our intended guarantee of "if your original program does not overflow, it also won't overflow in FHE".
  • For int-to-float and float-to-int, this also brings up the more complicated issue of data semantics for float vs fixed point.

To get AES working in the frontend, we can just add it and throw an error message in heir-opt when encountering these ops in a non-CGGI pipeline, but I wanted to document my thoughts/questions somewhere.

AlexanderViand avatar May 16 '25 22:05 AlexanderViand

OK to speak back to what we do want:

In CGGI land, we want these casts to get lowered to CGGI library casts https://github.com/google/heir/blob/55d46963b825837e4505c832ceb3f8fd1b7af491/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp#L185

That pipeline happens on the non-Yosys side of things. On the Yosys side, arith.extui/trunci are also implemented in Verilog and will yield better circuits too.

asraa avatar Jun 09 '25 17:06 asraa

With the cast-system introduced in #1825, I think it's safe to close this one for now, at least until we find a need for actual type hints

AlexanderViand avatar Jul 20 '25 00:07 AlexanderViand

@AlexanderViand I don't have the Python I scrapped from the bitwise C implementation above, but that was the meat of the AES implementation.

I think the AES funcs (in the common/aes folder) also require tensor slicing, which is the other thing I'm not sure works in the Python frontend.

asraa avatar Oct 06 '25 16:10 asraa