keras

ExportArchive cannot be used with JAX

Status: Open. martin-gorner opened this issue 1 year ago (0 comments).

Repro Colab: https://colab.research.google.com/drive/1TzvwkSY_EteBQuhNjki2i3Kv37EPfY-Y?usp=sharing

The use case for ExportArchive with a JAX model is converting the model with jax2tf manually.

Manual use of jax2tf is an important CUJ (critical user journey) for two reasons:

  • it is the only way to specify polymorphic shapes
  • it is the only way to include TF pre- and postprocessing ops in the inference function; JAX has no equivalent ops

Unfortunately, there does not seem to be any way to complete this CUJ at present.

Requested fix: it should be possible to create a TF function from a JAX model with jax2tf and export it using ExportArchive.

Note: wrapping the JAX variables in tf.Variable before calling jax2tf is necessary, as per the jax2tf documentation. Without it, saving through ExportArchive is possible, but all variables are saved into the graph as constants, which is very inefficient.

martin-gorner, Feb 01 '24 22:02