orbax
Orbax provides common utility libraries for JAX users.
Internal change
Replace deprecated gda_serialization with array_serialization.
Add method for fetching most-recent checkpoints
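The commit message does not show the new method. Below is a minimal, standalone sketch of the idea. It assumes a hypothetical layout with one subdirectory per training step, named by the bare step number. `list_steps` and `most_recent_steps` are illustrative names, not the Orbax API.

```python
import os
import re
import tempfile

# Hypothetical layout: each checkpoint is a subdirectory named by its step number.
# Names with any suffix (for example, an unfinished temporary directory) are ignored.
_STEP_DIR = re.compile(r"^\d+$")


def list_steps(root: str) -> list:
    """Return the step numbers of all checkpoint subdirectories, in ascending order."""
    steps = [
        int(name)
        for name in os.listdir(root)
        if _STEP_DIR.match(name) and os.path.isdir(os.path.join(root, name))
    ]
    return sorted(steps)


def most_recent_steps(root: str, n: int = 1) -> list:
    """Return up to `n` of the most recent steps, newest first."""
    return sorted(list_steps(root), reverse=True)[:n]
```

Sorting parsed integers rather than directory names avoids lexicographic surprises such as `"900" > "1000"`.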
Log the completion of renaming. The extra logging marks when the rename finishes, which can help in understanding the cost of the rename operation.
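A minimal sketch of timing and logging a rename follows. It assumes the common pattern of writing a checkpoint to a temporary directory and renaming it to its final name once all data is written. `finalize_checkpoint` is a hypothetical helper, not an Orbax function. On a single local POSIX filesystem the rename is cheap. On object stores such as GCS, a directory "rename" is typically a copy followed by a delete, so its cost can be significant.

```python
import logging
import os
import tempfile
import time

logger = logging.getLogger(__name__)


def finalize_checkpoint(tmp_dir: str, final_dir: str) -> float:
    """Rename a fully written temporary checkpoint directory to its final name.

    Logs both the start and the completion of the rename, and returns the
    elapsed time in seconds so the cost of the operation is visible.
    """
    logger.info("Renaming %s -> %s", tmp_dir, final_dir)
    start = time.monotonic()
    os.rename(tmp_dir, final_dir)  # atomic within a single local POSIX filesystem
    elapsed = time.monotonic() - start
    logger.info("Finished renaming %s -> %s in %.3f s", tmp_dir, final_dir, elapsed)
    return elapsed
```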
Add OCDBT support in Pax. If the feature is turned on, legacy checkpoints will be readable, but new checkpoints will be saved using OCDBT. A checkpoint version change is not...
Enable checkpointing for training inputs in Pax (using Orbax).
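Checkpointing training inputs means saving the input pipeline's position together with the model, so a restarted job resumes on the same data instead of replaying or skipping examples. The sketch below is a toy, standalone illustration of that idea. `CheckpointableIterator`, `get_state`, and `set_state` are hypothetical names and are not the Pax or Orbax API.

```python
import json
from typing import List


class CheckpointableIterator:
    """Hypothetical input iterator whose position can be saved and restored."""

    def __init__(self, examples: List[int]):
        self._examples = examples
        self._position = 0  # index of the next example to yield

    def __iter__(self) -> "CheckpointableIterator":
        return self

    def __next__(self) -> int:
        if self._position >= len(self._examples):
            raise StopIteration
        example = self._examples[self._position]
        self._position += 1
        return example

    def get_state(self) -> str:
        """Serialize the current position, e.g. to store alongside model weights."""
        return json.dumps({"position": self._position})

    def set_state(self, state: str) -> None:
        """Restore a position previously produced by get_state()."""
        self._position = json.loads(state)["position"]
```

A real pipeline would also need to capture shuffling seeds and any buffered state, not just an index.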
[JAX] Fix confusion between a pytree and a PyTreeDef in orbax. The confusion causes pytype errors under an upcoming JAX change.
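The commit does not show the affected code. The snippet below only illustrates the distinction using the public `jax.tree_util` API. A pytree is the nested container of values. A PyTreeDef is the structure-only description returned by `tree_flatten`. Passing one where the other is expected is the kind of mismatch a type checker such as pytype can flag.

```python
import jax

# A pytree: a nested structure of containers (dicts, lists, tuples, ...) holding leaves.
params = {"dense": {"w": 1.0, "b": 2.0}, "scale": [3.0]}

# tree_flatten splits the pytree into its leaves and a PyTreeDef, which records
# only the structure (container types and dict keys), not the leaf values.
leaves, treedef = jax.tree_util.tree_flatten(params)

# The PyTreeDef can rebuild a pytree with the same structure from new leaves.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
```

JAX flattens dict entries in sorted key order, so `leaves` is `[2.0, 1.0, 3.0]` here.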
Refactor Flax checkpointing to integrate more closely with the Orbax class structure. Flax's behavior can be duplicated using Orbax.
Fix incorrect usage of `is_checkpoint_finalized` in Checkpointer. The bug was reported when restoring a checkpoint on GCS: the checkpoint was always treated as "not finalized".
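The commit does not describe how finalization is detected. On object stores such as GCS, directory renames are not atomic. One common way to signal that a save is complete is to write a marker file as the very last step. The sketch below assumes that approach. `COMMIT_MARKER`, `write_commit_marker`, and `is_finalized` are hypothetical names and do not describe Orbax's implementation.

```python
import os
import tempfile

# Hypothetical marker file written only after all checkpoint data is on disk.
COMMIT_MARKER = "commit_success.txt"


def write_commit_marker(checkpoint_dir: str) -> None:
    """Mark a checkpoint as complete; call only after every file has been written."""
    with open(os.path.join(checkpoint_dir, COMMIT_MARKER), "w") as f:
        f.write("ok")


def is_finalized(checkpoint_dir: str) -> bool:
    """Return True if the given checkpoint directory contains the commit marker."""
    return os.path.isfile(os.path.join(checkpoint_dir, COMMIT_MARKER))
```

With a check like this, the path passed in must be the checkpoint directory itself. Any other path will never contain the marker, so the check always reports "not finalized".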
Support jax.Array in T5X. This feature is currently disabled but can be enabled by setting `use_gda=True` and `use_jax_array=True`.