Compatibility with Keras 3
I've been trying to migrate my TF 2.15 / TFP 0.23 code to Keras 3 with the JAX backend (keeping TF 2.18 for tf.data), and stumbled upon Distrax, which seems like an elegant drop-in solution :)
However, I've noticed that importing Distrax imports TFP via from tensorflow_probability.substrates import jax as tfp, which then asks for tf-keras, i.e. Keras 2:
Failed to import TF-Keras.
Please note that TF-Keras is not installed by default when you install TensorFlow Probability.
This is so that JAX-only users do not have to install TensorFlow or TF-Keras.
To use TensorFlow Probability with TensorFlow, please install the tf-keras or tf-keras-nightly package.
This can be be done through installing the tensorflow-probability[tf] extra.
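Before digging further, it can help to confirm which of the packages involved are actually present in the environment. Below is a minimal sketch that only uses the standard library's importlib.util.find_spec, so it checks availability without importing anything (and therefore without triggering the message above). The module names are assumptions: keras for Keras 3, tf_keras for the Keras 2 package installed by tf-keras, plus tensorflow_probability and distrax.

```python
import importlib.util


def installed(names):
    """Map each top-level module name to whether it can be found, without importing it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# Assumed import names: keras (Keras 3), tf_keras (the Keras 2 package the
# message asks for), tensorflow_probability and distrax.
status = installed(["keras", "tf_keras", "tensorflow_probability", "distrax"])
for name, found in status.items():
    print(f"{name:25s} {'found' if found else 'missing'}")
```

If tf_keras shows up as missing while the message still appears, that points at TFP performing the check during the Distrax import rather than at a stray Keras 2 install.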
Does that mean Distrax is incompatible with Keras 3?