orbax
Orbax provides common utility libraries for JAX users.
Introduce `ListKey` and `TupleKey` for `list` and `tuple` PyTree nodes respectively.
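The entry above does not show the new classes' API, so here is a hedged illustration of why list and tuple nodes might need distinct key types. JAX's built-in key paths use `jax.tree_util.SequenceKey` for both lists and tuples, so a path alone does not record which container held a leaf. The sketch uses hypothetical stand-ins (`ListIndex`, `TupleIndex`, `DictEntry`), not Orbax's `ListKey`/`TupleKey`. It shows how typed path keys let a flattened tree be rebuilt with the right container types.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical path-key types for illustration only; NOT Orbax's ListKey/TupleKey.
@dataclass(frozen=True)
class DictEntry:
    key: Any

@dataclass(frozen=True)
class ListIndex:
    idx: int

@dataclass(frozen=True)
class TupleIndex:
    idx: int

def flatten_with_typed_path(tree, prefix=()):
    """Yield (path, leaf) pairs; each path element records its container type."""
    if isinstance(tree, dict):
        for k in sorted(tree):
            yield from flatten_with_typed_path(tree[k], prefix + (DictEntry(k),))
    elif isinstance(tree, list):
        for i, child in enumerate(tree):
            yield from flatten_with_typed_path(child, prefix + (ListIndex(i),))
    elif isinstance(tree, tuple):
        for i, child in enumerate(tree):
            yield from flatten_with_typed_path(child, prefix + (TupleIndex(i),))
    else:
        yield prefix, tree

def unflatten(pairs):
    """Rebuild the tree; the key type decides dict vs. list vs. tuple.

    Empty containers produce no leaves, so this sketch cannot restore them.
    """
    if len(pairs) == 1 and pairs[0][0] == ():
        return pairs[0][1]
    groups = {}
    for path, leaf in pairs:
        groups.setdefault(path[0], []).append((path[1:], leaf))
    first = next(iter(groups))
    if isinstance(first, DictEntry):
        return {k.key: unflatten(v) for k, v in groups.items()}
    children = [unflatten(groups[k]) for k in sorted(groups, key=lambda k: k.idx)]
    return children if isinstance(first, ListIndex) else tuple(children)

tree = {"a": [1, 2], "b": (3, (4, 5))}
pairs = list(flatten_with_typed_path(tree))
rebuilt = unflatten(pairs)
```

With a single shared sequence-key type, `unflatten` would have to guess whether `"b"` was a list or a tuple. Distinct key types remove the guess.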
Add Metadata Checkpointer for PWoC.
Support nested `asyncio.run` without `nest_asyncio` library.
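A common stdlib-only technique for this is sketched below as an assumption, not as Orbax's actual implementation. If no event loop is running in the current thread, call `asyncio.run` directly. Otherwise, run the coroutine on a fresh loop in a worker thread and block until it finishes. The helper name `run_sync` is hypothetical.

```python
import asyncio
import concurrent.futures
from typing import Any, Coroutine, TypeVar

T = TypeVar("T")

def run_sync(coro: Coroutine[Any, Any, T]) -> T:
    """Hypothetical helper: run `coro` to completion, even inside a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so the plain call is safe.
        return asyncio.run(coro)
    # A loop is already running here; asyncio.run would raise RuntimeError,
    # so give the coroutine its own loop on a separate thread and wait for it.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()

async def inner() -> int:
    await asyncio.sleep(0)
    return 42

async def outer() -> int:
    # Calling asyncio.run(inner()) directly here would raise RuntimeError.
    return run_sync(inner())

print(asyncio.run(outer()))  # 42
```

The outer loop is blocked while it waits. For that reason, the inner coroutine must not depend on tasks scheduled on the outer loop.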
Add monitoring to record how long directory creation takes (per sequential instance).
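A minimal sketch of per-call timing for directory creation follows. It assumes an in-process list of durations; the metric store and `timed_makedirs` are hypothetical. A real implementation might instead report through JAX's monitoring hooks, such as `jax.monitoring.record_event_duration_secs`. Orbax's actual metric names are not shown here.

```python
import os
import tempfile
import time
from collections import defaultdict

# Hypothetical in-process metric store: metric name -> list of durations (seconds).
DURATIONS_SECS: dict[str, list[float]] = defaultdict(list)

def timed_makedirs(path: str, metric: str = "directory_creation_secs") -> float:
    """Create `path` (and parents), record the elapsed wall-clock time, and return it."""
    start = time.perf_counter()
    os.makedirs(path, exist_ok=True)
    elapsed = time.perf_counter() - start
    DURATIONS_SECS[metric].append(elapsed)  # one sample per call
    return elapsed

with tempfile.TemporaryDirectory() as root:
    for step in range(3):
        timed_makedirs(os.path.join(root, f"step_{step}"))

print(len(DURATIONS_SECS["directory_creation_secs"]))  # 3
```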
Add `blocking_metadata_write` option to allow disabling async metadata write behavior.
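The sketch below illustrates what such a toggle typically controls. The `MetadataWriter` class is hypothetical: it borrows the option's name but not Orbax's actual signature. With the flag set, `write` returns only after the file is on disk. Without it, the write finishes on a background thread, and `wait_until_finished` joins that thread.

```python
import json
import os
import tempfile
import threading

class MetadataWriter:
    """Hypothetical writer showing blocking vs. asynchronous metadata writes."""

    def __init__(self, blocking_metadata_write: bool = False):
        self.blocking_metadata_write = blocking_metadata_write
        self._thread = None

    @staticmethod
    def _write_text(path: str, text: str) -> None:
        with open(path, "w") as f:
            f.write(text)

    def write(self, path: str, metadata: dict) -> None:
        # Finish any earlier async write first so writes stay ordered.
        self.wait_until_finished()
        # Serialize eagerly so later mutation of `metadata` cannot race the write.
        text = json.dumps(metadata)
        if self.blocking_metadata_write:
            self._write_text(path, text)
        else:
            self._thread = threading.Thread(target=self._write_text, args=(path, text))
            self._thread.start()

    def wait_until_finished(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None

results = {}
with tempfile.TemporaryDirectory() as d:
    for blocking in (True, False):
        path = os.path.join(d, f"metadata_{blocking}.json")
        writer = MetadataWriter(blocking_metadata_write=blocking)
        writer.write(path, {"step": 7})
        writer.wait_until_finished()  # no-op in blocking mode
        with open(path) as f:
            results[blocking] = json.load(f)
```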
Restore from local with the mutated mesh, and transfer arrays to get back Pytree with original mesh before broadcasting.
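Orbax's emergency-restore logic is more involved than this. The sketch below covers only the final resharding step, using public JAX APIs. It assumes the locally restored arrays come back on a "mutated" mesh, modeled here as the same devices in reversed order. It then moves each leaf of the PyTree back onto the original mesh with `jax.device_put`. The mesh and axis names are illustrative.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
n = len(devices)

# Mesh the caller expects the restored PyTree to live on.
original_mesh = Mesh(np.array(devices), ("data",))
# Assumed "mutated" mesh: same devices, different order.
mutated_mesh = Mesh(np.array(devices[::-1]), ("data",))

original_sharding = NamedSharding(original_mesh, P("data"))
mutated_sharding = NamedSharding(mutated_mesh, P("data"))

# Stand-in for a local restore: the PyTree arrives laid out on the mutated mesh.
restored_local = {
    "w": jax.device_put(np.arange(2 * n, dtype=np.float32), mutated_sharding)
}

# Transfer every leaf back to the original mesh before any broadcast step.
restored = jax.tree_util.tree_map(
    lambda x: jax.device_put(x, original_sharding), restored_local
)
```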
I have been trying to use Orbax to checkpoint Flax NNX models, and getting checkpointing to work for models with Dropout layers, which also hold JAX RNG keys, is not...
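The report above is truncated, so its exact failure is unknown. One frequent source of trouble, offered here as an assumption, is that typed PRNG keys (created with `jax.random.key`) have an extended dtype that array serializers may not handle. A common workaround is to save the raw `uint32` key data and re-wrap it after restoring. This is a hedged sketch, not Orbax's or Flax's documented approach.

```python
import jax
import numpy as np

# A typed PRNG key, like the ones a Dropout layer's RNG state may hold.
key = jax.random.key(0)

# Extract the underlying uint32 data, which any array checkpointer can store.
key_data = np.asarray(jax.random.key_data(key))

# ... save `key_data` with a checkpointer, then load it back ...

# Re-wrap the raw data as a typed key (default PRNG implementation assumed).
restored_key = jax.random.wrap_key_data(key_data)
```

Because the round trip preserves the key bits, random draws from `restored_key` match those from the original key.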
Add emergency checkpoint logging of arrays for debugging.
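As a hedged illustration of debug logging for checkpointed arrays, the sketch below logs each array's shape, dtype, and a short content digest. Copies held on different hosts can then be compared line by line. The logger name and both helpers are hypothetical; Orbax's actual log format is not shown here.

```python
import hashlib
import logging

import numpy as np

logger = logging.getLogger("emergency_checkpoint_debug")  # hypothetical name

def describe_array(name: str, arr: np.ndarray) -> str:
    """One-line summary: shape, dtype and a truncated SHA-256 of the raw bytes."""
    arr = np.ascontiguousarray(arr)
    digest = hashlib.sha256(arr.tobytes()).hexdigest()[:16]
    return f"{name}: shape={arr.shape} dtype={arr.dtype} sha256[:16]={digest}"

def log_arrays(tree: dict[str, np.ndarray]) -> list[str]:
    """Log one summary line per array, in sorted key order, and return the lines."""
    lines = [describe_array(k, v) for k, v in sorted(tree.items())]
    for line in lines:
        logger.debug(line)
    return lines

logging.basicConfig(level=logging.DEBUG)
lines = log_arrays({"b": np.zeros(3, np.float32), "a": np.arange(4)})
```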
Bumps the pip group with 1 update in the /docs/requirements directory: [ipython](https://github.com/ipython/ipython). Updates `ipython` from 7.23.1 to 8.10.0. Commits: 15ea1ed release 8.10.0; 560ad10 DOC: Update what's new for 8.10 (#13939)...
Test restoring to `global_mesh` in parallel.