orbax
Orbax provides common utility libraries for JAX users.
Add bfloat16 flags and support bfloat16 optimization for native JAX functions.
Dear all, I wonder if it would be possible to support saving and loading zero size arrays. I saw it was deactivated in https://github.com/google/orbax/pull/1570. Allowing it would be very useful...
Introduce an enum `CheckpointingImpl` to simplify Pathways registration and to allow multiple options to be specified and resolved in order of priority.
Introduce ArrayReadSpec, replacing get_json_tspec_read usage
Always write array metadata if self._array_metadata_store is not None.
Internal change.
Add `checkpoint_manager_test.py` to multiprocess tests.
Update checkpoint manager benchmark to use PyTree args and add deletion.
Internal Change
Summary: When using `nnx.List` (which stores items with integer dict keys) with Orbax checkpointing, integer keys are converted to strings during save but not converted back during restore. This...
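To illustrate the key-type mismatch described in this report, here is a minimal sketch in plain Python. It does not use the Orbax or `nnx` APIs; `restore_int_keys` and the sample `saved` tree are hypothetical names introduced only for this example, and the approach (post-processing a restored tree) is one possible workaround under that assumption, not the fix adopted by the library.

```python
from typing import Any


def restore_int_keys(tree: Any) -> Any:
    """Recursively turn decimal-string dict keys (e.g. "0") back into ints.

    Hypothetical helper for illustration only; not part of Orbax.
    """
    if isinstance(tree, dict):
        return {
            (int(k) if isinstance(k, str) and k.isdecimal() else k): restore_int_keys(v)
            for k, v in tree.items()
        }
    if isinstance(tree, list):
        return [restore_int_keys(v) for v in tree]
    if isinstance(tree, tuple):
        return tuple(restore_int_keys(v) for v in tree)
    # Leaves (arrays, scalars, strings, ...) are returned unchanged.
    return tree


# Assumed shape of a restored tree whose integer keys came back as strings.
saved = {"layers": {"0": {"w": 1.0}, "1": {"w": 2.0}}, "step": 7}
restored = restore_int_keys(saved)
```

Note that this sketch converts every decimal-looking string key, so a tree that legitimately uses keys such as `"0"` or `"007"` as strings would be altered; a real fix would need type information recorded at save time.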