
Write-up for MLP inference on MNIST using HEIR CKKS pipeline

ZenithalHourlyRate opened this issue 11 months ago • 8 comments

MLP is one of the simplest forms of neural network, and supporting it could be a starting point for more complex networks (e.g. the widely used ResNet benchmark in the HE literature).

Indeed, such support could be imported from a frontend like TOSA/STABLEHLO, but as #1079 suggests, HEIR currently does not support some linalg ops, so we have to write it manually. Also, there are still some limitations in the HEIR CKKS pipeline, so manual intervention is a must.

The essential code is in https://github.com/ZenithalHourlyRate/heir/commit/617e0d163690710c5ed44840553e9193d936ea33, including the MLIR MLP impl and corresponding LLVM/OpenFHE driver.

The weights and test data needed are at https://cloud.tsinghua.edu.cn/f/b2da50f8bdbc4aa1859f/

Network Design

MNIST is a dataset of handwritten digits (28x28 images), and an MLP can be used to classify them (into 10 labels); the accuracy is usually good enough, around 95%.

The typical MLP design involves the following parts:

  • The first Fully-Connected layer, of size 784x512.
  • The activation layer, using ReLU as the activation function.
  • The second Fully-Connected layer, of size 512x10.

In HEIR, we have implemented Halevi-Shoup matrix multiplication, but we only support square matrices for now; also, ReLU can hardly be expressed with HE primitives, so a polynomial approximation is needed (see #658, #665 and #1217). A minimal cleartext sketch of such an approximation follows.
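
For intuition, here is a cleartext sketch of one standard approach from the HE literature (not necessarily the exact polynomial used in this issue): approximate sign(x) by iterating the odd polynomial f(t) = (3t - t^3)/2, which converges to sign(t) on [-1, 1], and recover ReLU as x * (1 + sign(x)) / 2. Inputs are assumed pre-scaled into [-1, 1].

// Iterate f(t) = (3t - t^3) / 2; each step pushes t toward sign(t) on [-1, 1].
float approx_sign(float t, int iterations = 3) {
  for (int i = 0; i < iterations; ++i)
    t = 0.5f * (3.0f * t - t * t * t);
  return t;
}

// relu(x) ~= x * (1 + sign(x)) / 2, with sign replaced by its approximation.
float approx_relu(float x) {
  return x * (1.0f + approx_sign(x)) * 0.5f;
}

Because only additions and multiplications appear, the same computation maps directly onto CKKS ciphertext operations.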

So, the specialized network design becomes

  • The first FC layer, of size 1024x1024
  • The activation layer, using an approximate ReLU based on a polynomial approximation
  • The second FC layer, of size 1024x1024

Padding is added accordingly, as sketched below.
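
For illustration, a sketch of how the trained 784x512 and 512x10 weights are zero-extended into 1024x1024 matrices (padToSquare is a hypothetical helper, not part of HEIR):

#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Zero-extend a trained weight matrix into a dim x dim square so the
// square-only Halevi-Shoup matmul can consume it.
Matrix padToSquare(const Matrix &w, size_t dim = 1024) {
  Matrix padded(dim, std::vector<float>(dim, 0.0f));
  for (size_t i = 0; i < w.size(); ++i)
    for (size_t j = 0; j < w[i].size(); ++j)
      padded[i][j] = w[i][j];  // trained block lands in the top-left corner
  return padded;
}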

Training

Training of this specialized network was done by @yushijinhun using PyTorch, achieving an accuracy of 96%.

Inference

There are two versions of the inference MLIR implementation in the code above. mlp.mlir shows that the cleartext computation itself is correct (lowered to LLVM and run via a C function), and mlp_inline.mlir can be accepted by the HEIR pipeline to produce OpenFHE code.

Cleartext computation (for verifying correctness)

mlp.mlir contains the following code

func.func @approx_sign(%x: tensor<1x1024xf32>) -> tensor<1x1024xf32>
// detailed impl

func.func @approx_relu(%x: tensor<1x1024xf32>) -> tensor<1x1024xf32>
// detailed impl

func.func @mlp(%input: tensor<1x1024xf32>, %fc1: tensor<1024x1024xf32>, %fc2: tensor<1024x1024xf32>, %fc1_buffer: tensor<1x1024xf32>, %fc2_buffer: tensor<1x1024xf32>) -> tensor<1x1024xf32> attributes {llvm.emit_c_interface} {
  %fc1_result = linalg.matmul ins(%input, %fc1 : tensor<1x1024xf32>, tensor<1024x1024xf32>) outs(%fc1_buffer : tensor<1x1024xf32>) -> tensor<1x1024xf32>
  %relu1 = call @approx_relu(%fc1_result) : (tensor<1x1024xf32>) -> tensor<1x1024xf32>
  %fc2_result = linalg.matmul ins(%relu1, %fc2 : tensor<1x1024xf32>, tensor<1024x1024xf32>) outs(%fc2_buffer : tensor<1x1024xf32>) -> tensor<1x1024xf32>
  return %fc2_result : tensor<1x1024xf32>
}
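
As a cross-check against the MLIR version, a plain C++ sketch of the row-vector-times-matrix core (names are illustrative; the actual driver lives in mlp.cpp):

// y[1x1024] = x[1x1024] * W[1024x1024]; W is row-major, all buffers caller-owned.
void ref_matmul(const float *x, const float *w, float *y, int dim = 1024) {
  for (int j = 0; j < dim; ++j) {
    float acc = 0.0f;
    for (int i = 0; i < dim; ++i)
      acc += x[i] * w[i * dim + j];
    y[j] = acc;
  }
}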

The corresponding C interface is

extern "C" {
void *_mlir_ciface_mlp(MemRefDescriptor<float, 2> *output,
                       MemRefDescriptor<float, 2> *input,
                       MemRefDescriptor<float, 2> *fc1,
                       MemRefDescriptor<float, 2> *fc2,
                       MemRefDescriptor<float, 2> *fc1_buffer,
                       MemRefDescriptor<float, 2> *fc2_buffer);
}

The reason we use such a function signature is that when lowering to LLVM (and then interfacing with C), we have to deal with memref.alloc, and ambiguous buffer ownership can lead to memory bugs. A sketch of the descriptor layout follows.
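
For concreteness, here is a sketch of the descriptor the driver fills in. MemRefDescriptor is assumed to mirror MLIR's strided memref ABI (StridedMemRefType in mlir/ExecutionEngine/CRunnerUtils.h), and wrap_1x1024 is a hypothetical helper:

#include <cstdint>
#include <vector>

// Assumed layout of MLIR's strided memref ABI for a rank-N memref.
template <typename T, int N>
struct MemRefDescriptor {
  T *allocated;        // pointer as returned by the allocator
  T *aligned;          // aligned pointer the kernel actually reads/writes
  int64_t offset;      // element offset into `aligned`
  int64_t sizes[N];    // per-dimension sizes
  int64_t strides[N];  // row-major strides
};

// Wrap caller-owned storage as a 1x1024 tensor, so the generated code never
// allocates or frees memory across the C boundary.
MemRefDescriptor<float, 2> wrap_1x1024(std::vector<float> &storage) {
  return {storage.data(), storage.data(), /*offset=*/0, {1, 1024}, {1024, 1}};
}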

The lowering and compilation are done with the following steps

heir-opt --heir-polynomial-to-llvm mlp.mlir | mlir-translate --mlir-to-llvmir | llc --relocation-model=pic -o mlp.s
clang++ -o mlp_main mlp.cpp mlp.s libmlir_c_runner_utils.so -fPIE

Running ./mlp_main, we get an accuracy of 9634/10000.

Homomorphic computation

To move things into the homomorphic domain, specifically the HEIR CKKS pipeline, the following transformations are done

# HEIR does not support func.call for now, so inline everything
mlir-opt --inline mlp.mlir > mlp_inline.mlir
# then delete the @approx_sign and @approx_relu functions in it

The tricky part of the current matmul implementation is that its weights/buffers must all be arith.constant rather than arguments, so we make the following change to mlp_inline.mlir. Note that the filled-in weights are extracted from the trained weights above, so we end up with a 24MB mlp_inline.mlir

%buffer0 = arith.constant dense<0.0> : tensor<1x1024xf32>
%weight0 = arith.constant dense<[[ /* fill weight here! */ ]]> : tensor<1024x1024xf32>
%fc1_result = linalg.matmul ins(%arg0, %weight0 : tensor<1x1024xf32>, tensor<1024x1024xf32>) outs(%buffer0 : tensor<1x1024xf32>) -> tensor<1x1024xf32>

// similarly for fc2

Then we need to change the function signature to

func.func @mlp(%arg0: tensor<1x1024xf32> {secret.secret}) -> tensor<1x1024xf32>

Then we do the following transforms

# Note that due to the large input size, all of the following commands are slow;
# this specific command takes ~10 minutes
heir-opt --mlir-to-openfhe-ckks="entry-function=mlp" mlp_inline.mlir > mlp_openfhe.mlir
## or detailed command below to inspect each step
## Halevi-Shoup matmul
#heir-opt --mlir-to-secret-arithmetic mlp_inline.mlir > mlp_secret_arithmetic.mlir
## Annotate RNS information; this shows that we need 9 levels to finish the computation.
#heir-opt --secret-insert-mgmt-ckks mlp_secret_arithmetic.mlir > mlp_mgmt.mlir
## lower to CKKS dialect
#heir-opt --secret-distribute-generic --secret-to-ckks mlp_mgmt.mlir > mlp_ckks.mlir
## lower to OpenFHE dialect
#heir-opt --lwe-add-client-interface="use-public-key=true one-value-per-helper-fn=true" --ckks-to-lwe --lwe-to-openfhe --openfhe-configure-crypto-context="entry-function=mlp" mlp_ckks.mlir > mlp_openfhe.mlir
# Translate to C++ openfhe function
# Note that the weights are so large that by default they are printed in HEX form, which the current heir-translate OpenFHE emitter cannot handle
heir-translate '--mlir-print-elementsattrs-with-hex-if-larger=-1' --emit-openfhe-pke --openfhe-scheme=ckks mlp_openfhe.mlir > mlp_openfhe.cpp
heir-translate --emit-openfhe-pke-header --openfhe-scheme=ckks mlp_openfhe.mlir > mlp_openfhe.h

To make it work, we need the following changes in mlp__generate_crypto_context() in mlp_openfhe.cpp (be careful: this file is 30MB, and opening it in an IDE may hang it):

  • Delete .SetPlaintextModulus(); CKKS does not need it.
  • Add v14431.SetSecurityLevel(HEStd_NotSet); v14431.SetRingDim(1 << 11); (change v14431 to your own variable name). Because our ciphertext packs 1024 values, full packing needs RingDim 2048; meeting the default security level would push RingDim to 32768, which is too slow. A sketch of the resulting function follows this list.
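
A sketch of roughly what the edited function might look like, assuming OpenFHE's CCParams API (the depth of 9 comes from the level count reported above; the actual emitted code will differ in detail):

#include "openfhe.h"

using namespace lbcrypto;

CryptoContext<DCRTPoly> mlp__generate_crypto_context() {
  CCParams<CryptoContextCKKSRNS> params;
  params.SetMultiplicativeDepth(9);       // 9 levels per secret-insert-mgmt-ckks
  // No SetPlaintextModulus(): CKKS does not have one.
  params.SetSecurityLevel(HEStd_NotSet);  // WARNING: insecure, benchmarking only
  params.SetRingDim(1 << 11);             // ring dim 2048 = 1024 CKKS slots, full packing
  CryptoContext<DCRTPoly> cc = GenCryptoContext(params);
  cc->Enable(PKE);
  cc->Enable(KEYSWITCH);
  cc->Enable(LEVELEDSHE);
  return cc;
}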

Compile it against mlp_openfhe_main.cpp

# you should add -I and -L for your specific OpenFHE installation
# takes around 30s to compile
clang++ -std=c++17 -o mlp_openfhe mlp_openfhe_main.cpp mlp_openfhe.cpp -I. -lOPENFHEcore -lOPENFHEpke -lOPENFHEbinfhe
# set an unlimited stack size, as all the weights now live on the stack;
# otherwise you will observe a segfault
ulimit -s unlimited
# run the program
./mlp_openfhe

We get the following output. Each inference takes ~30 seconds, and memory usage is ~3GB.

Element Parameters: ILDCRTParams [m=4096 n=2048 q=1867228318734141816999655779671310589399405959067135205574019961160007458457284606914047149436113163649889870755858538873363336506924688056138331269655941214209 ru=0 bigq=0 bigru=0]
  m_params:
    0: ILParams [m=4096 n=2048 q=1152921504606830593 ru=459811883340678 bigq=0 bigru=0]
    1: ILParams [m=4096 n=2048 q=1125899907260417 ru=479982368344 bigq=0 bigru=0]
    2: ILParams [m=4096 n=2048 q=1125899907063809 ru=809174752721 bigq=0 bigru=0]
    3: ILParams [m=4096 n=2048 q=1125899907219457 ru=105703250448 bigq=0 bigru=0]
    4: ILParams [m=4096 n=2048 q=1125899906977793 ru=996886817494 bigq=0 bigru=0]
    5: ILParams [m=4096 n=2048 q=1125899907145729 ru=842109450467 bigq=0 bigru=0]
    6: ILParams [m=4096 n=2048 q=1125899906990081 ru=25027901798 bigq=0 bigru=0]
    7: ILParams [m=4096 n=2048 q=1125899907096577 ru=203477575998 bigq=0 bigru=0]
    8: ILParams [m=4096 n=2048 q=1125899906826241 ru=1080667890455 bigq=0 bigru=0]
    9: ILParams [m=4096 n=2048 q=1125899906949121 ru=251751765212 bigq=0 bigru=0]
    10: ILParams [m=4096 n=2048 q=557057 ru=66 bigq=0 bigru=0]


Encoding Parameters: [p=50 rootP =0 bigP =0 rootBigP =0 g=0 L=1024]

max_id: 7, label: 7
max_id: 2, label: 2
max_id: 1, label: 1
max_id: 0, label: 0
max_id: 4, label: 4
max_id: 1, label: 1
max_id: 4, label: 4
max_id: 9, label: 9
max_id: 6, label: 5
max_id: 9, label: 9
accuracy: 9/10

If we use RingDim 32768 instead, one inference takes ~6 minutes and memory usage is ~26GB.

Discussion

  • The limitations of the current Halevi-Shoup matmul implementation are the most painful part: the IR is big (~20MB), the C++ output is big (~30MB), and the stack usage is big (hence ulimit -s unlimited). We should support loading weights at runtime instead of hard-coding them in the IR; the technical issue is then issuing many tensor ops to correctly pack the plaintext weight matrix (see the sketch after this list).
  • I think every complex example/benchmark should have a cleartext version (lowered to LLVM) and a homomorphic version (lowered to a backend), so we can ensure correctness (it is easier to debug whether the fault lies in the input program or in a compiler transformation), and we can further use the two lowerings to compare cleartext vs. homomorphic computation efficiency.
  • I think this example fits well under tests/Example/benchmark; I want to discuss how to integrate it as a benchmark example for HEIR.
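
Regarding the first bullet, a sketch of the runtime weight loading idea (load_weights and the flat float32 file layout are hypothetical, not HEIR APIs):

#include <cstddef>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Read a row-major float32 blob produced at training time, instead of baking
// a ~24MB dense<...> arith.constant into the IR.
std::vector<float> load_weights(const std::string &path, size_t rows, size_t cols) {
  std::ifstream in(path, std::ios::binary);
  if (!in) throw std::runtime_error("cannot open " + path);
  std::vector<float> w(rows * cols);
  in.read(reinterpret_cast<char *>(w.data()),
          static_cast<std::streamsize>(w.size() * sizeof(float)));
  if (!in) throw std::runtime_error("short read from " + path);
  return w;  // ready to be packed into plaintexts at runtime
}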

ZenithalHourlyRate · Dec 24 '24

This is really a fantastic start! Thank you for being so thorough.

I want to get this work checked in, so I'm trying to figure out how to tease apart the overall thing into different work items we can make progress on.

Instead of an arith.constant, we should be able to use a memref.global and the get_global op to avoid putting all the data on the stack.

At a high level, our main goal should be to ensure the file sizes are small, runtimes are fast, and that things can be compressed appropriately. So the goal to put the weights in a separate file should mainly be to enable ease of compression. For MLIR, I think we could get most of those benefits by having the values be globals and storing the files as MLIR bytecode (not plaintext), and we could easily check in the bytecode files as test inputs. Adding bytecode is a slightly more complicated task, in that it requires versioning our IR, so I'd like to avoid that as long as possible, and maybe that means having large-ish plaintext checked-in files for a while. So long as we avoid the compile time and stack usage issues above, we can find a good solution for having large test files.

j2kun · Jan 06 '25

I think every complex example/benchmark should have a cleartext version (lowered to LLVM) and a homomorphic version (lowered to a backend), so we can ensure correctness

I fully support this, I'm just trying to think through how we could make this easier to do than the current method, which involves relatively complex RUN directives in lit files.

j2kun · Jan 06 '25

Another subtask: we should support emitting hex values in the emitter. I would bet using emitc instead of manual emission would get this for free, but in the meantime we might be able to port the code from emitc's codegen to get the same effect.

j2kun · Jan 06 '25

Thinking about this some more, could we get the input model saved in an independent format, say ONNX or stablehlo, and use that as a starting point?

j2kun · Jan 06 '25

@ZenithalHourlyRate from my end, I made the necessary modifications to preserve the loop so the resulting file isn't too slow to compile. For the constants, I was able to serialize them and write them to a proto, and use a helper function in the emitter to load the weights into memory. (It still loads all weights into memory at once, though, since the loop is preserved, and the calls to extract a slice will slice that memory.)

I'll be pushing the piece-meal changes as I go.

I did not serialize as bytecode, but I will try to do that in my PRs.

asraa · Feb 05 '25

I tried to run the cleartext version of this from the mlp directory, but I do get an error about a missing symbol.

The command I run is: ../../../../bazel-out/darwin_x86_64-dbg/bin/tools/heir-opt --heir-polynomial-to-llvm $PWD/mlp.mlir | ../../../../bazel-out/darwin_x86_64-dbg/bin/external/llvm-project/mlir/mlir-translate --mlir-to-llvmir | ../../../../bazel-out/darwin_x86_64-dbg/bin/external/llvm-project/llvm/llc --relocation-model=pic -o mlp.s | clang++ -o mlp_main mlp.cpp mlp.s -L../../../../bazel-out/darwin_x86_64-dbg/bin/external/llvm-project/mlir/ -lmlir_c_runner_utils -lmlir_runner_utils -fPIE

and the error I get is:

Undefined symbols for architecture x86_64:
  "_memrefCopy", referenced from:
      _mlp in mlp-ea5cac.o
ld: symbol(s) not found for architecture x86_64

Could this be a similar error as in #1521?

kragall · Mar 21 '25

-lmlir_c_runner_utils and -lmlir_runner_utils should already provide the symbol memrefCopy. I did not know that the macOS toolchain mangles the name into _memrefCopy. I think this is a similar problem as in #1521.

ZenithalHourlyRate · Mar 21 '25

hieunch's solution from #1521 helps here too, and the command now runs without errors.

kragall · Mar 26 '25

I tried to run the example using heir-opt --mlir-to-secret-arithmetic mlp_inline.mlir. At least on my machine there's a core dump happening at the layout-propagation pass.

Assertion failed: (Index < Length && "Invalid index!"), function operator[], file ArrayRef.h, line 253.
...
#8 0x0000000104394604 llvm::ArrayRef<long long>::operator[](unsigned long) const (/private/var/tmp/_bazel_mar69689/946252ae710c30be6410a6296d2b67c1/execroot/_main/bazel-out/darwin_arm64-dbg/bin/tools/heir-opt+0x100004604)
#9 0x0000000104f5e8c8 mlir::heir::LayoutPropagation::visitOperation(mlir::tensor::CollapseShapeOp) (/private/var/tmp/_bazel_mar69689/946252ae710c30be6410a6296d2b67c1/execroot/_main/bazel-out/darwin_arm64-dbg/bin/tools/heir-opt+0x100bce8c8)
...

I know there's work going on to get a new layout system working, so this might not be an issue once the new layout system is finished.

kragall · Aug 21 '25

I know there's work going on to get a new layout system working, so this might not be an issue once the new layout system is finished.

I'm afraid the old .mlir file is not, and will not be, compatible with the new layout system. One option is to roll back to last year's commit to reproduce it; see https://github.com/ZenithalHourlyRate/heir/commits/mlp-ckks/

ZenithalHourlyRate · Aug 21 '25

We have MNIST lowering through HEIR natively now. A few minor cleanups are pending, so I'm closing this issue.

j2kun · Sep 17 '25