jit_compile control in `BaseImageAugmentationLayer`
This is just a follow-up to https://github.com/keras-team/keras-cv/issues/165#issuecomment-1083502165
@qlzh727 What do you think about adding an extra parameter to the base class for jit_compile? https://github.com/keras-team/keras/blob/master/keras/layers/preprocessing/image_preprocessing.py#L413-L414
So that we could optionally use something like:
```python
f = def_function.function(self._augment, jit_compile=True)
self._map_fn(f, inputs)
```
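For illustration, a rough sketch of what that could look like on the base class. The class name, the `jit_compile` flag, and the use of `tf.map_fn` here are assumptions for the sketch, not the actual `BaseImageAugmentationLayer` internals:

```python
import tensorflow as tf

# Hypothetical sketch only: names and structure are assumptions, not the
# real BaseImageAugmentationLayer implementation.
class AugmentationLayerSketch(tf.keras.layers.Layer):
    def __init__(self, jit_compile=False, **kwargs):
        super().__init__(**kwargs)
        self._jit_compile = jit_compile

    def _augment(self, image):
        # Per-example augmentation logic would live here.
        return tf.image.flip_left_right(image)

    def call(self, inputs):
        augment_fn = self._augment
        if self._jit_compile:
            # Same idea as the def_function.function(...) call above.
            augment_fn = tf.function(self._augment, jit_compile=True)
        return tf.map_fn(augment_fn, inputs)

layer = AugmentationLayerSketch(jit_compile=True)
images = tf.random.uniform((4, 32, 32, 3))
out = layer(images)  # each image goes through the XLA-compiled augment fn
```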
Triage notes:
- Such an option should likely not be exposed to end users, so it doesn't need to be an argument.
- Some layers may require `jit_compile` in order to be performant. In such cases we should just make the `tf.function(jit_compile=True)` directly part of the layer implementation (see the sketch below).
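A compact sketch of that second option. The layer and its augmentation are hypothetical, chosen only to show the pattern of baking `tf.function(jit_compile=True)` into the implementation rather than exposing it to users:

```python
import tensorflow as tf

class RandomInvertSketch(tf.keras.layers.Layer):
    """Hypothetical layer: jit compilation is part of the implementation,
    not something the end user toggles."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Compiled once at construction; never exposed as an argument.
        self._augment_xla = tf.function(self._augment, jit_compile=True)

    def _augment(self, image):
        # Simple value inversion, chosen only because it lowers cleanly to XLA.
        return 1.0 - image

    def call(self, inputs):
        return tf.map_fn(self._augment_xla, inputs)
```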
What was the logic behind exposing this in model compile instead?
I suppose that if we already let the user choose whether to jit_compile in the model compile API, we don't want to automatically compile layers without any user control.
https://github.com/keras-team/keras/blob/39ad2c1cb22b231baf05a0218322328c13654bda/keras/engine/training.py#L532
/cc @qlzh727 @LukeWood
I suppose we will see a small "explosion" of XLA jit_compile failures once we enable XLA compilation.
And they will be more fatal than the ones we have in https://github.com/keras-team/keras-cv/issues/291 for tf.vectorized_map.
tf.vectorized_map has an automatic (slower) fallback, but XLA has a fail-fast policy: the first op used in a layer implementation that has no TF2XLA implementation will break the jit compilation entirely.
XLA's notion of a "fallback" is different: light outside compilation (GPU only), and it has to be implemented in TF2XLA for every op that you use in your implementation (a CPU/TPU HLO implementation is still required if you want to jit_compile on those devices).
As I am quite new to XLA internals, /cc @cheshire in case he wants to add some advice.
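A minimal way to surface those failures early, assuming a layer that is callable on a sample batch. The helper name is made up, and the exact exception type can vary; `tf.errors.OpError` is just the common base class:

```python
import tensorflow as tf

def check_xla_compiles(layer, sample_inputs):
    """Run a layer under jit_compile=True to surface unsupported ops early.

    Unlike tf.vectorized_map, which silently falls back to a slower path,
    XLA raises as soon as it meets an op without a TF2XLA kernel.
    """
    compiled = tf.function(layer, jit_compile=True)
    try:
        compiled(sample_inputs)
        return True
    except tf.errors.OpError as e:
        print(f"XLA compilation failed: {e}")
        return False
```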
@fchollet @qlzh727 Can you migrate this to keras-cv now? It seems @LukeWood doesn't have enough rights in this repo to do the migration.
As of Keras 3, `jit_compile` is set to `"auto"` in `model.compile`, which means it will use XLA if the model allows it.
Can we close the issue, considering the Keras 3 implementation?
Is it binary for the library user? e.g. whole model compile or nothing?
From the doc: https://keras.io/api/models/model_training_apis/
jit_compile: Bool or "auto". Whether to use XLA compilation when compiling a model. For jax and tensorflow backends, jit_compile="auto" enables XLA compilation if the model supports it, and disabled otherwise. For torch backend, "auto" will default to eager execution and jit_compile=True will run with torch.compile with the "inductor" backend.
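To make the model-level nature of the switch concrete, a toy Keras 3 example (the model itself is purely illustrative):

```python
import keras

# Toy model: the switch is per-model, not per-layer.
model = keras.Sequential([
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(1),
])

# "auto" lets Keras decide: XLA on the tensorflow/jax backends if the model
# supports it, eager on torch unless jit_compile=True (torch.compile/inductor).
model.compile(optimizer="adam", loss="mse", jit_compile="auto")
```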
But here we were talking about `BaseImageAugmentationLayer`, not the whole model interface. In any case, I don't know all the new refactors, so do what you want with this ticket.
`BaseImageAugmentationLayer` subclasses the Keras `Layer` class, and `Layer` has an attribute that makes it operate in jit by default: `self.supports_jit = True`.
Below is the code reference.
https://github.com/keras-team/keras/blob/7ce3d62af7cc6959fc5a5841cfe17043dfcb8615/keras/layers/layer.py#L275
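A minimal sketch, assuming a hypothetical layer, of how that attribute can be flipped to opt a layer out of jit compilation:

```python
import keras

class NotJitFriendlyLayer(keras.layers.Layer):
    """Hypothetical layer that opts out of jit compilation."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Layer.__init__ sets supports_jit = True by default (see the link
        # above); flipping it to False signals that this layer should not be
        # XLA-compiled, so jit_compile="auto" can fall back for the model.
        self.supports_jit = False

    def call(self, inputs):
        return inputs
```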
Please check and close the issue if there are no further questions. Thanks.
As `self.supports_jit` is a bool, how do we control XLA vs. torch.compile (Dynamo) vs. other compilers?
A layer that can be compiled for one stack doesn't mean it can be compiled with another backend.