Compare SOTA FHE polynomial approximations with HEIR's CF approximation
Given that ReLU(x) = x (0.5 + 0.5 sgn(x)), this reduces to approximating the sign function, and this paper appears to have the state of the art: https://eprint.iacr.org/2020/834
Also note:

max(u, v) = ((u + v) + (u - v) sign(u - v)) / 2
min(u, v) = -max(-u, -v) = ((u + v) - (v - u) sign(v - u)) / 2
Also cf. https://openreview.net/pdf?id=Hq16Jk2bVlp and https://eprint.iacr.org/2021/1688 which use these approximations.
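To make the identities above concrete, here's a tiny numpy sketch that plugs a placeholder low-degree odd polynomial in for sign (not one of the minimax polynomials from the papers). Note that inputs must be scaled into the approximation's domain, here [-1, 1], which is a real constraint in the FHE setting as well:

```python
import numpy as np

# Placeholder for a real minimax sign approximation: a crude odd cubic
# that fixes -1 and 1 and has the right sign on [-1, 1].
def p_sign(x):
    return 0.5 * (3.0 * x - x ** 3)

def relu_approx(x):
    # ReLU(x) = x * (0.5 + 0.5 * sign(x))
    return x * (0.5 + 0.5 * p_sign(x))

def max_approx(u, v):
    # max(u, v) = ((u + v) + (u - v) * sign(u - v)) / 2
    # valid only while u - v stays inside the approximation's domain
    return ((u + v) + (u - v) * p_sign(u - v)) / 2

x = np.linspace(-1.0, 1.0, 5)
print(relu_approx(x))          # ~[0., -0.08, 0., 0.42, 1.]: crude, but the right shape
u = np.linspace(-0.5, 0.5, 5)
print(max_approx(u, -u))       # roughly |u|
```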
The paper linked above (2020/834) does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124
The paper https://eprint.iacr.org/2019/1234 is a precursor to https://eprint.iacr.org/2020/834, but also seems to explain more of the motivation behind the composite polynomials.
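For intuition about why composition helps, here's a minimal sketch (my own, not from the papers' code): composing a fixed degree-3 odd polynomial with itself sharpens the approximation to sign doubly exponentially in the depth, while multiplicative depth grows only linearly. The cubic below is the classic f(x) = (3x - x^3)/2; the papers derive better compositions, so treat this as illustrative only:

```python
import numpy as np

def f1(x):
    # Odd cubic fixing +/-1; iterating it pushes any x in [-1, 1] toward sign(x).
    return 0.5 * (3.0 * x - x ** 3)

def composite_sign(x, depth):
    # sign(x) ~= f1(f1(...f1(x))): a degree-3^depth polynomial, evaluated
    # with only 2 nonscalar multiplications per level.
    for _ in range(depth):
        x = f1(x)
    return x

xs = np.array([-0.9, -0.1, 0.05, 0.5])
print(composite_sign(xs, 10))   # ~[-1, -1, 1, 1]
```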
An example of generating a well-fitting polynomial using lolremez: https://github.com/samhocevar/lolremez/issues/28#issuecomment-1913324892
Another tool: https://github.com/pychebfun/pychebfun
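For prototyping without a new dependency, numpy's Chebyshev interpolation is a near-minimax alternative (not the true Remez optimum, but close for smooth functions; sign is discontinuous at 0, so fit a smooth surrogate or fit away from the origin):

```python
import numpy as np
from numpy.polynomial.chebyshev import Chebyshev

f = lambda x: np.tanh(20.0 * x)   # smooth stand-in for sign
approx = Chebyshev.interpolate(f, deg=31, domain=[-1, 1])

xs = np.linspace(-1, 1, 2001)
print(np.max(np.abs(approx(xs) - f(xs))))   # sup-norm error of the fit
```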
Outline sketch:
- Use an existing Remez solver to find any old approximation to sgn or max of some arbitrary degree, e.g., from https://github.com/google/heir/issues/658#issuecomment-2087937786
- Implement a lowering to that fixed polynomial, and run an e2e test
- Use a Paterson-Stockmeyer approach to minimize mul ops (https://www.csd.uwo.ca/~mmorenom/HPCA-ACA-2017/Sivan_Toledo.ACA-2017-Talk.pdf has some notes, looking for a better source)
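For the Paterson-Stockmeyer item, here's a minimal plaintext sketch of the baby-step/giant-step structure (scalar-coefficient multiplications are cheap in FHE; the point is that the nonscalar multiplications drop to O(sqrt(degree))):

```python
import numpy as np

def paterson_stockmeyer(coeffs, x):
    """Evaluate sum(coeffs[i] * x**i) with O(sqrt(len(coeffs))) nonscalar mults.

    Baby steps: precompute x^2..x^k. Giant steps: split the polynomial into
    degree-<k chunks and combine them with Horner's rule in x^k.
    """
    n = len(coeffs)
    k = max(1, int(np.ceil(np.sqrt(n))))
    powers = [1.0, x]
    for _ in range(k - 1):            # k - 1 baby-step multiplications
        powers.append(powers[-1] * x)
    xk = powers[k]                    # x^k
    result = 0.0
    for start in range(((n - 1) // k) * k, -1, -k):
        chunk = sum(c * powers[i] for i, c in enumerate(coeffs[start:start + k]))
        result = result * xk + chunk  # one nonscalar mult per giant step
    return result

coeffs = [0.0, 1.875, 0.0, -1.25, 0.0, 0.375]   # an example odd polynomial
assert np.isclose(paterson_stockmeyer(coeffs, 0.7),
                  np.polyval(coeffs[::-1], 0.7))
```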
Various improvements based on more recent research would be worth splitting into separate tickets:
- Multi-interval Remez solver for a better sgn approximation: Algorithm 2 of https://eprint.iacr.org/2020/834
- Domain extension polynomial to improve sqrt(n) muls to log(n): https://eprint.iacr.org/2022/280
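My loose reading of the domain-extension idea, as a sketch: pick a cubic b that is close to the identity near 0 and maps a slightly larger interval onto [-1, 1]; composing scaled copies of it compresses a huge interval into [-1, 1] with O(log) multiplications, and for a saturating function like sign the compression doesn't change the answer. The cubic and the constant L below are my own choice (b maps [-1.5, 1.5] onto [-1, 1]); the paper's actual DEP construction and constants may differ:

```python
import numpy as np

b = lambda x: x - (4.0 / 27.0) * x ** 3   # monotone on [-1.5, 1.5], b(+/-1.5) = +/-1
L = 1.5

def extend_domain(x, n):
    # Compresses [-L**n, L**n] into [-1, 1] with n cubic evaluations,
    # leaving small x nearly unchanged and preserving sign throughout.
    for i in reversed(range(n)):
        s = L ** i
        x = s * b(x / s)   # scaled copy of b acting on [-L**(i+1), L**(i+1)]
    return x

xs = np.array([-5.0, -0.2, 0.3, 5.0])
print(extend_domain(xs, 4))   # in [-1, 1]; signs (and small values) preserved
```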
Thanks to Seonhong Min for sending me https://eprint.iacr.org/2018/462, which shows that BFV achieves polynomial approximations via a fixed-point approximation, not a floating-point one. I think there is also some complexity there, in that evaluating a fixed-point polynomial approximation also requires a rounding step (though not always; see Sec. 2.5).
I also had an interest in this issue a while back, so I know a paper worth sharing: https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=10155408 (a follow-up paper by the authors of https://eprint.iacr.org/2020/834), and there's an implementation: https://github.com/snu-ccl/approxCNN/tree/main. In the repo, you'll find that they've hardcoded the coefficients of polynomial approximations of the sign function for alpha = 4 to 14!
> The paper linked above (2020/834) does not release source code, but there is a reference implementation in Lattigo: https://github.com/tuneinsight/lattigo/blob/4cce9a48c1daaa2dd122921822f5ad70cd444156/he/hefloat/minimax_composite_polynomial.go#L124
Oh, I didn't know there was an implementation in Lattigo! In that case, these hardcoded coefficients might not be necessary.
After starting an implementation in https://github.com/google/heir/pull/665 (with a fixed approximation polynomial) and discussing in the HEIR meeting today, I have a few kinks to work out:
1. I want to separate the task of choosing a polynomial approximation from the optimizations around evaluating it. This implies:
   1. I need a floating-point representation of a polynomial in the IR, but `PolynomialAttr` currently only supports integer coefficients.
   2. I need a new op, say `poly_ext.eval`, whose operands are the polynomial to apply and its input.
2. The approximation itself is for sign, but these operations are actually applied to `max(0, x) = (x + x * sign(x)) / 2`, which means we should support some internal polynomial arithmetic to construct these from the checked-in approximation (a sketch of this arithmetic follows the list). We meant to do this to support constant folding in the polynomial dialect, but never got around to it.
3. (1) has a twist in that many more advanced polynomial approximations are not represented literally, but implicitly as a composition of smaller-degree polynomials. This implies I will need a `polynomial.compose` op, or else an attribute that supports composite sub-polynomials, and cannot limit (1.ii) above to a single static polynomial. I think I will start with a single static polynomial but try to avoid making it difficult to upgrade to a composite one.
4. The approximation polynomial itself has a few quirks, because its coefficients further need to be encoded in a scheme-specific fashion. For CKKS this is relatively straightforward, but it introduces additional error. For BGV/BFV this seems much harder, in part because the encodings are fixed-point and hence require rounding during polynomial evaluation, and rounding itself is hard (see above). There is also a question about which basis the polynomial is expressed in; cf. https://discord.com/channels/901152454077452399/1235349479482196049 for more on this.
5. The above points expose a problem with "lowering a ReLU": at the tosa level we don't yet know what scheme will be chosen, so the choice of polynomial approximation can't be scheme-aware or encoding-aware. I think the right solution here will be to include some extra metadata on the polynomial to express what function is being approximated, so that we can re-approximate it at lower levels if necessary.
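As a stand-in for the missing polynomial-dialect arithmetic, here's what the sign-to-ReLU construction from point 2 looks like with numpy's Polynomial (the sign coefficients are placeholders, not the real checked-in approximation):

```python
import numpy as np
from numpy.polynomial import Polynomial

p_sign = Polynomial([0.0, 1.5, 0.0, -0.5])   # placeholder: (3x - x^3) / 2
x = Polynomial([0.0, 1.0])                   # the identity polynomial

# max(0, x) = (x + x * sign(x)) / 2, assembled by polynomial arithmetic
p_relu = (x + x * p_sign) / 2

print(p_relu.coef)                  # [0., 0.5, 0.75, 0., -0.25]
print(p_relu(0.8), p_relu(-0.8))    # ~0.78 and ~-0.02 on [-1, 1]
```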
These folks do something slightly different and more holistic with respect to NN training: https://github.com/EfficientFHE/SmartPAF
They pick small degree polynomial approximations and then do a variety of network fine-tuning to adapt the network to the replaced operations. This seems out of scope of the compiler, since it would require training data to be included.
I added an upstream RFC for the polynomial approximation pass https://discourse.llvm.org/t/rfc-a-polynomial-approximation-pass/79301
@jianmingTONG has also worked on the SmartPAF system (https://github.com/EfficientFHE/SmartPAF) for approximating activation functions.