prompt-tuning

Checkpointing of a T5 model results in serialization error with new Jax 0.4.5

Open: emersodb opened this issue 1 year ago • 2 comments

Flax removed flax.optim in favor of Optax in the releases after 0.5.3, and it is gone as of Flax 0.6. Running the code in this repository therefore requires downgrading Flax to a version below 0.6. However, with the downgraded Flax and either Jax 0.4.5, or Jax 0.3.25 with jax.config.update('jax_array', True), the code cannot save a model checkpoint because msgpack is unable to serialize the Jax arrays.

Expected Behavior

Model should be able to be saved as a checkpoint.

Actual Behavior

[Screenshot of the error output; image not preserved]

Steps to Reproduce the Problem

With Jax 0.4.5 and Flax 0.5.3, the issue can be reproduced minimally in a Python REPL:

[Screenshot of the REPL session; image not preserved]

emersodb, Mar 11 '23 17:03