keras
Backend-agnostic JIT compilation support for non-supervised-learning workflows
Hello!
Keras 3 provides a very nice backend-agnostic way to enable or disable eager mode and JIT compilation for supervised learning workflows: set run_eagerly and jit_compile in model.compile() before calling model.fit().
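For reference, here is a minimal sketch of that supervised path, assuming Keras 3 with a backend that supports training. The model, data shapes and epoch count are illustrative only. In Keras 3, jit_compile defaults to "auto", so it is passed explicitly here just to show where the flag lives.

```python
import numpy as np
import keras

# Tiny regression model; architecture is illustrative only.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])

# Both execution flags are set once, at compile time.
# jit_compile="auto" lets Keras decide; pass True/False to force XLA on/off.
model.compile(optimizer="adam", loss="mse", run_eagerly=False, jit_compile="auto")

# Random data just to exercise fit().
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

# fit() builds its train function according to the flags above.
history = model.fit(x, y, epochs=2, batch_size=8, verbose=0)
```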
However, in other workflows such as reinforcement learning, where the model is typically called directly (e.g. model(state) in this official example), we seem to lose this capability. Or am I missing something?
It would be great to support this without having to rewrite boilerplate code. I believe that model.predict_on_batch() almost does the job (it seems to take into account the jit_compile value set in model.compile()), except that it currently neither keeps the output as a tensor (it converts it to NumPy) nor allows gradient computation.
Any recommendations or ideas for making alternative workflows work smoothly with eager/JIT execution, whatever the backend, would be very welcome. Thanks!
I don't think explicitly enabling JIT will give a great performance boost for the example you mentioned.
If you implement jit_compile yourself, it may cause gradient recording issues, so you need to be careful with the design.
Thank you for your reply, but I am not sure I understand your point about "gradient recording issues". As I understand it, JIT compilation is the new default in Keras 3 (at least for the JAX and TF backends) when training with model.fit(), which of course relies on gradient descent, and it dramatically reduces training time.
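For what it's worth, JIT compilation and gradient computation compose directly, at least in JAX. A minimal sketch, using JAX as one concrete backend; the linear model, loss and data are illustrative and not from the thread:

```python
import jax
import jax.numpy as jnp

def mse_loss(w, x, y):
    # Linear prediction followed by mean squared error.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad builds the gradient function; jax.jit compiles it with XLA.
grad_fn = jax.jit(jax.grad(mse_loss))

w = jnp.array([1.0, 2.0])
x = jnp.eye(2)    # identity inputs, so pred == w
y = jnp.zeros(2)  # zero targets, so the gradient reduces to w
g = grad_fn(w, x, y)
```

With identity inputs and zero targets the loss is the mean of w squared, so the compiled gradient should equal w itself.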
I wish we had the same capability in a cross-backend way when just calling the model in a custom training or inference loop (without model.fit()).
We still need a decorator function in Keras that calls jax.jit, tf.function(jit_compile=True), etc. under the hood, depending on the backend.
In downstream code, ideally we wouldn't have to write any backend-specific code; Keras should be the single API for accelerating helper functions that are not strictly part of a model or layer but are needed in the training or inference pipeline.
Something like this:
@keras.jit  # proposed unified API (hypothetical)
def some_fn(x, y, z):
    ...
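One way to approximate the proposed decorator today is a small user-side helper that dispatches on keras.backend.backend(). This is a minimal sketch, not a Keras API: the name backend_jit and its backend override parameter are hypothetical, and only the underlying calls (jax.jit, tf.function(..., jit_compile=True), torch.compile) are real library functions.

```python
import functools

def backend_jit(fn=None, *, backend=None):
    """Hypothetical backend-agnostic JIT decorator (NOT part of Keras).

    Usable as @backend_jit or @backend_jit(backend="jax").
    """
    if fn is None:
        # Called with arguments: return a decorator bound to `backend`.
        return functools.partial(backend_jit, backend=backend)
    if backend is None:
        # Imported lazily so the helper can be exercised without Keras.
        import keras
        backend = keras.backend.backend()
    if backend == "jax":
        import jax
        return jax.jit(fn)
    if backend == "tensorflow":
        import tensorflow as tf
        return tf.function(fn, jit_compile=True)
    if backend == "torch":
        import torch
        return torch.compile(fn)
    # "numpy" or any unknown backend: no compiler available, run eagerly.
    return fn
```

In normal use, @backend_jit would wrap helpers written only with keras.ops and pick up whichever backend is active. The explicit backend argument exists mainly so the dispatch can be tested without Keras installed. With JAX the decorated function must be pure, so calling a stateful Keras model inside it generally means going through model.stateless_call(trainable_variables, non_trainable_variables, x) and passing the variables in explicitly.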
Also, somewhat related: when I install only jax[cuda], I still get warning printouts like this one, though code:
tests/test_dataset_ops.py INFO:2025-04-28 11:23:31,648:jax._src.xla_bridge:867: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
No TPU ops were requested.
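The TPU message above is logged at INFO level while JAX probes the platforms it knows about. If the goal is to skip that probing, one option (assuming a reasonably recent JAX release) is to restrict the platforms JAX initializes through the JAX_PLATFORMS environment variable, set before jax is first imported. This is a sketch of a workaround on the user side, not a change in Keras.

```python
import os

# Restrict JAX to the CUDA backend so it should not try to load libtpu.
# setdefault keeps any value already exported in the shell.
# This must run before the first `import jax` in the process.
os.environ.setdefault("JAX_PLATFORMS", "cuda")

import jax  # backends are initialized lazily, on first use
```

On a CPU-only machine "cpu" would be the value to use instead; with "cuda" explicitly requested, JAX should raise an error on first use if no GPU is available rather than silently falling back.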