Provide an efficient inference implementation using sparsification/quantization
Goal: reduce inference time of the model using quantization
We published CPU inference timing results for CMS in 2021 (https://cds.cern.ch/record/2792320/files/DP2021_030.pdf, slide 16): “For context, on a single CPU thread (Intel i7-10700 @ 2.9GHz), the baseline PF requires approximately (9 ± 5) ms, the MLPF model approximately 320 ± 50 ms for Run 3 ttbar MC events”.
Now is a good time to make inference as fast as possible while minimizing any physics impact.
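To compare a baseline and a quantized model on equal footing, a small timing helper can be useful. The following is only a minimal sketch using the Python standard library: `time_inference` and `dummy_model` are hypothetical names, and `dummy_model` is a stand-in for a real forward pass on one event. It reports the mean and standard deviation over repeated calls, which is not necessarily how the spread in the numbers quoted above was defined.

```python
import statistics
import time


def time_inference(fn, n_warmup=3, n_repeat=20):
    """Return (mean_ms, std_ms) of the wall-clock time of calling fn()."""
    for _ in range(n_warmup):
        fn()  # warm-up calls are executed but not timed
    samples_ms = []
    for _ in range(n_repeat):
        t0 = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - t0) * 1e3)  # seconds -> ms
    return statistics.mean(samples_ms), statistics.stdev(samples_ms)


def dummy_model():
    # Hypothetical stand-in for a model forward pass on a single event.
    return sum(i * i for i in range(10_000))


mean_ms, std_ms = time_inference(dummy_model)
print(f"{mean_ms:.3f} ± {std_ms:.3f} ms per call")
```

For a single-thread comparison with a PyTorch model, one would typically also restrict the thread count first, for example with `torch.set_num_threads(1)`.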
Resources:
- https://www.youtube.com/playlist?list=PL80kAHvQbh-pT4lCkDT53zT8DKmhE0idB
- https://www.tensorflow.org/model_optimization/guide
- https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html
- https://github.com/fastmachinelearning/qonnx
- https://github.com/calad0i/HGQ
- https://arxiv.org/pdf/2307.02973.pdf
adding @raj2022
Also related: https://github.com/jpata/particleflow/issues/315
Basically, to summarize:
- with @raj2022 we saw that it's possible to quantize the model to int8 in PyTorch using post-training static quantization, following the recipe in https://github.com/jpata/particleflow/blob/main/notebooks/clic/mlpf-pytorch-transformer-standalone.ipynb
- the important features were a custom attention layer (in the notebook), and introducing per-feature quantization stubs
- we also showed that a very performant model can be trained using only ReLU activations, so this work did improve the compute budget
- however, the exported int8 model was not faster on either CPU or GPU
- getting an actual speedup most likely requires a more informed approach, to make sure the int8 attention is actually computed using efficient ops on the hardware
- the summary notebook was added in https://github.com/jpata/particleflow/pull/297
- ONNX may be a better path for performant quantization in the end, but this requires more study.
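
To illustrate why per-feature quantization matters, here is a minimal plain-Python sketch of affine int8 quantization. It is not the code from the notebook, which uses PyTorch quantization stubs; the helper names and the two-feature calibration batch are made up for illustration. It compares one shared scale for all input features against a separate scale per feature, when the features have very different ranges.

```python
QMIN, QMAX = -128, 127  # signed int8 range


def calibrate(values):
    """Derive an affine (scale, zero_point) pair from calibration values."""
    lo = min(min(values), 0.0)  # include 0.0 so that zero is exactly representable
    hi = max(max(values), 0.0)
    scale = (hi - lo) / (QMAX - QMIN) or 1.0  # fall back to 1.0 for an all-zero input
    zero_point = round(QMIN - lo / scale)
    return scale, zero_point


def quantize(x, scale, zero_point):
    q = round(x / scale) + zero_point
    return max(QMIN, min(QMAX, q))  # clamp to the int8 range


def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale


def max_roundtrip_error(values, scale, zero_point):
    """Largest |x - dequantize(quantize(x))| over the given values."""
    return max(
        abs(x - dequantize(quantize(x, scale, zero_point), scale, zero_point))
        for x in values
    )


# Hypothetical calibration batch: rows are inputs, columns are two features
# with very different ranges (order 1 versus order 1000).
batch = [[0.1, 50.0], [0.5, 400.0], [0.9, 1000.0], [0.0, 0.0]]
columns = [list(col) for col in zip(*batch)]

# One (scale, zero_point) shared by all features.
shared = calibrate([x for row in batch for x in row])
# One (scale, zero_point) per feature.
per_feature = [calibrate(col) for col in columns]

err_shared = max_roundtrip_error(columns[0], *shared)
err_per_feature = max_roundtrip_error(columns[0], *per_feature[0])
print(f"feature 0 max error, shared scale:      {err_shared:.4f}")
print(f"feature 0 max error, per-feature scale: {err_per_feature:.4f}")
```

With a shared scale, the small-range feature is squeezed into a few integer levels and loses most of its information, while a per-feature scale keeps the round-trip error within half a quantization step. This says nothing about speed: whether int8 ops are actually faster depends on the kernels available on the target hardware.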
I'm closing this issue and adding a separate roadmap item to study ONNX post-training static quantization. Many thanks to @raj2022 for your contributions!