Tensorflow is still imported when using jax or torch backend

ageron opened this issue on Jun 11 '23 • 4 comments

When I import keras_core, it imports TensorFlow even when I set the backend to jax or torch:

>>> import os
>>> os.environ["KERAS_BACKEND"] = "jax"
>>> import keras_core
2023-06-12 11:14:46.431809: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Using JAX backend.
>>> keras_core.backend.backend()
'jax'

TensorFlow takes up to 3-4 seconds to load on my machine, so it would be nice to avoid that. And of course it would be nice not to have to install it at all when using another backend, since it's quite a big beast and uses a lot of disk space.

ageron, Jun 11 '23
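One way to check the behavior reported above, including the import cost, is to import the package in a fresh interpreter with `KERAS_BACKEND` set and then look at `sys.modules`. A fresh subprocess matters because a second `import` in the same process is served from the module cache and costs almost nothing. The sketch below uses a hypothetical helper, `probe_import`, relies only on the standard library, and times just the `import` statement itself.

```python
import os
import subprocess
import sys
import textwrap


def probe_import(module, probe="tensorflow", backend=None):
    """Hypothetical helper: import `module` in a fresh interpreter.

    Returns (seconds spent on the import, whether `probe` ended up in
    sys.modules as a side effect of that import).
    """
    env = dict(os.environ)
    if backend is not None:
        # The backend must be chosen before the package is imported.
        env["KERAS_BACKEND"] = backend
    script = textwrap.dedent(f"""
        import sys, time
        start = time.perf_counter()
        import {module}
        elapsed = time.perf_counter() - start
        print(elapsed)
        print({probe!r} in sys.modules)
    """)
    # check=True raises CalledProcessError if the import itself fails.
    result = subprocess.run(
        [sys.executable, "-c", script],
        env=env, capture_output=True, text=True, check=True,
    )
    # Use the last two lines in case the package prints its own messages.
    lines = result.stdout.strip().splitlines()
    return float(lines[-2]), lines[-1] == "True"
```

For the scenario in this issue, `probe_import("keras_core", backend="jax")` would be expected to report `True` for `tensorflow`, along with whatever import time TensorFlow adds on that machine.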

I agree that we need to remove the required TF dependency. There are several reasons why this isn't possible right now:

  • We use tf.nest to process nested Python structures. First, we'd need to extract nest into a standalone Python library.
  • We use tf.data for data preprocessing in various places. We could make this optional though.
  • We use TF to implement preprocessing layers such as lookup layers, hashing layers, etc. There is no other realistic way to implement these ops, short of doing a rewrite in e.g. Rust.

Eventually we'll make TF optional, but it will take some time.

fchollet, Jun 11 '23
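For readers unfamiliar with `tf.nest`, the core of what it provides for nested Python structures can be sketched in pure Python. This is a hypothetical minimal version, not TensorFlow's implementation: the names `flatten`, `pack_sequence_as`, and `map_structure` mirror `tf.nest`'s function names, only plain dicts, lists, and tuples are treated as containers, and dict keys are assumed to be sortable. Like `tf.nest`, it visits dict values in sorted-key order so the flattened order doesn't depend on insertion order.

```python
from typing import Any, Callable, List

_MISSING = object()  # sentinel meaning "ran out of leaves"


def flatten(structure: Any) -> List[Any]:
    """Return the leaves of a nested dict/list/tuple structure.

    Dict values are visited in sorted-key order for determinism.
    """
    if type(structure) is dict:
        return [leaf for key in sorted(structure) for leaf in flatten(structure[key])]
    if type(structure) in (list, tuple):
        return [leaf for item in structure for leaf in flatten(item)]
    return [structure]  # anything else counts as a single leaf


def pack_sequence_as(structure: Any, leaves: List[Any]) -> Any:
    """Rebuild the shape of `structure`, filled with `leaves` in flatten() order."""
    it = iter(leaves)

    def rebuild(node: Any) -> Any:
        if type(node) is dict:
            # Consume leaves in sorted-key order, but keep the original key order.
            values = {key: rebuild(node[key]) for key in sorted(node)}
            return {key: values[key] for key in node}
        if type(node) in (list, tuple):
            return type(node)([rebuild(item) for item in node])
        leaf = next(it, _MISSING)
        if leaf is _MISSING:
            raise ValueError("not enough leaves for this structure")
        return leaf

    result = rebuild(structure)
    if next(it, _MISSING) is not _MISSING:
        raise ValueError("too many leaves for this structure")
    return result


def map_structure(fn: Callable[[Any], Any], structure: Any) -> Any:
    """Apply `fn` to every leaf while keeping the nesting intact."""
    return pack_sequence_as(structure, [fn(leaf) for leaf in flatten(structure)])
```

For example, `flatten({"foo": 1, "bar": {"baz": 2}})` returns `[2, 1]`, because `"bar"` sorts before `"foo"`.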

Is tf.nest still used to process nested Python structures?

vulkomilev, Sep 05 '23

Hi @ageron,

In Keras 3, you don't need tf.nest to handle deeply nested inputs in functional models. You can pass a dictionary, or a nested dictionary (more than one level deep), directly as the model's inputs. Here you can find more details about it.

import keras

# Nested dict of inputs: the model accepts this structure directly.
inputs = {
    "foo": keras.Input(shape=(1,), name="foo"),
    "bar": {
        "baz": keras.Input(shape=(1,), name="baz"),
    },
}
outputs = inputs["foo"] + inputs["bar"]["baz"]
model = keras.Model(inputs, outputs)

This nested input also works with the JAX and torch backends. I've attached a gist for your reference.

mehtamansi29, Sep 13 '24

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

github-actions[bot], Sep 28 '24

This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further.

github-actions[bot], Oct 13 '24
