
[OpenVINO backend] Supporting inference for Gemma with the OV backend

Open · Mohamed-Ashraf273 opened this issue 5 months ago • 3 comments

Description of the change

As part of my GSoC 2025 project to support inference with the OpenVINO backend, this PR adds support for the Gemma pipeline.
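
For context, a minimal sketch of the workflow this PR targets, assuming a Keras version that ships the OpenVINO backend; the preset name and sampler choice here are illustrative assumptions, not taken from the PR:

```python
import os

# The backend must be chosen before Keras is imported.
os.environ["KERAS_BACKEND"] = "openvino"

import keras_hub

# Load a Gemma preset and run generation on the OpenVINO backend.
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.compile(sampler="greedy")
print(gemma_lm.generate("What is OpenVINO?", max_length=64))
```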

Reference

https://docs.openvino.ai/2025/index.html
https://keras.io/api/
https://keras.io/keras_hub/

Colab Notebook

Checklist

  • [x] I have added all the necessary unit tests for my change.
  • [x] I have verified that my change does not break existing code and works with all backends (TensorFlow, JAX, and PyTorch).
  • [x] My PR is based on the latest changes of the main branch (if unsure, rebase the code).
  • [x] I have followed the Keras Hub Model contribution guidelines in making these changes.
  • [x] I have followed the Keras Hub API design guidelines in making these changes.
  • [x] I have signed the Contributor License Agreement.

Mohamed-Ashraf273 avatar Jun 22 '25 15:06 Mohamed-Ashraf273

Left some initial comments! But probably the first question is around the changes to causal_lm and gemma_causal_lm. Why is this so backend-specific? This is much more involved than the changes for jax/torch/tensorflow.

Hi @mattdangerw, I've been working on enabling inference for Gemma with OpenVINO by implementing the missing operations. The main challenge I ran into is that building the entire graph as a single model makes execution difficult (it takes too long, and RAM may overflow); I'm still investigating why. To address this, I introduced a subgraph-based approach that splits the full graph at key points. I also added logic to store compiled subgraphs in CausalLM so they can be reused across generations instead of being rebuilt each time (see the caching sketch below).

Mohamed-Ashraf273 avatar Jun 24 '25 20:06 Mohamed-Ashraf273
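
To make the subgraph-reuse idea above concrete, here is a minimal caching sketch; the `build_subgraph` callback and the cache key are hypothetical placeholders, not the PR's actual code:

```python
import openvino as ov

_core = ov.Core()
_compiled_cache = {}  # maps a split-point key -> ov.CompiledModel

def get_compiled_subgraph(key, build_subgraph):
    """Compile a subgraph once and reuse it across generate() calls."""
    if key not in _compiled_cache:
        # build_subgraph() returns an ov.Model for one split point
        # of the full Gemma graph.
        _compiled_cache[key] = _core.compile_model(build_subgraph(), "CPU")
    return _compiled_cache[key]
```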

@Mohamed-Ashraf273 is there a way that we can land this without the subgraph approach?

We have a similar need in JAX at train time. Compilation times are much improved if you run a common transformer block in a compiled loop (see the sketch below). So there is probably something to do here, but we'd really like to avoid our forward pass being a switch case on backend; that will lead to maintenance hell.

So maybe let's try to land with the same approach as the other backends for now, and see if there's a layer-stacking/compilation-reuse solution we can land as a follow-up?

mattdangerw avatar Jun 25 '25 18:06 mattdangerw
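
For reference, a minimal JAX sketch of the compiled-loop idea described above, assuming identical per-layer weight shapes; `block_fn` is an illustrative stand-in for one decoder block, not Gemma's actual layer:

```python
import jax
import jax.numpy as jnp

def block_fn(x, layer_params):
    # One "transformer block" applied to activations x (toy body).
    return jax.nn.gelu(x @ layer_params["w"])

@jax.jit
def forward(x, stacked_params):
    # stacked_params["w"] has shape (num_layers, d, d): all layers'
    # weights stacked on a leading axis, so XLA traces and compiles
    # the block once instead of once per layer.
    def step(carry, layer_params):
        return block_fn(carry, layer_params), None
    out, _ = jax.lax.scan(step, x, stacked_params)
    return out

# Example: 8 identical layers, hidden size 16.
params = {"w": 0.01 * jnp.ones((8, 16, 16))}
print(forward(jnp.ones((2, 16)), params).shape)  # (2, 16)
```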

Hi @mattdangerw, I removed the subgraph approach and the compiled-subgraph reuse. Now the model can run inference with OpenVINO and passes all tests; we just need to think about how to optimize inference with large parameter counts and real weights without overflowing RAM. I'd appreciate it if you could take another look. Thanks!

Mohamed-Ashraf273 avatar Jun 26 '25 06:06 Mohamed-Ashraf273
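
For a sense of scale on the RAM concern, a back-of-envelope sketch; the parameter count and fp32 precision are assumptions for illustration, not measurements from the PR:

```python
# Rough weight memory for a ~2.5B-parameter Gemma checkpoint in fp32
# (real presets may use bf16/fp16, halving this).
params = 2.5e9
bytes_per_param = 4  # fp32
print(f"{params * bytes_per_param / 2**30:.1f} GiB")  # ~9.3 GiB
```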

@mattdangerw My PR is ready for review!

Mohamed-Ashraf273 avatar Jun 28 '25 13:06 Mohamed-Ashraf273

@fchollet Can you take a look?

Mohamed-Ashraf273 avatar Jul 01 '25 09:07 Mohamed-Ashraf273

@mattdangerw

Mohamed-Ashraf273 avatar Jul 04 '25 15:07 Mohamed-Ashraf273

Hi @fchollet, I'd appreciate any feedback on my PR. Thanks!

Mohamed-Ashraf273 avatar Jul 09 '25 13:07 Mohamed-Ashraf273

/gemini review

divyashreepathihalli avatar Jul 10 '25 23:07 divyashreepathihalli

@fchollet @mattdangerw @rkazants @divyashreepathihalli

Mohamed-Ashraf273 avatar Jul 14 '25 16:07 Mohamed-Ashraf273

@mattdangerw @divyashreepathihalli

Mohamed-Ashraf273 avatar Jul 21 '25 19:07 Mohamed-Ashraf273