orbax
Orbax provides common utility libraries for JAX users.
#v1 Add `LeafHandler` as a `CheckpointableHandler`, so that ordinary PyTree leaves can also be saved as individual checkpointables.
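Saving ordinary leaves as individual checkpointables means each leaf of a PyTree gets its own addressable name. The following is a minimal pure-Python sketch of that idea. It assumes nested dicts and lists stand in for PyTrees and uses a hypothetical `leaf_checkpointables` helper that maps a `/`-joined key path to each leaf. It does not use or reproduce Orbax's `LeafHandler` API.

```python
from typing import Any, Dict


def leaf_checkpointables(tree: Any, prefix: str = "") -> Dict[str, Any]:
    """Hypothetical helper: flatten nested dicts/lists/tuples into {path: leaf}.

    Each entry could then be passed to a per-leaf handler and saved on its own.
    """
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        # A non-container value is a leaf: return it under its accumulated path.
        return {prefix or "root": tree}
    out: Dict[str, Any] = {}
    for key, child in items:
        path = f"{prefix}/{key}" if prefix else str(key)
        out.update(leaf_checkpointables(child, path))
    return out
```

For example, `{"params": {"w": 1, "b": [2, 3]}, "step": 10}` yields four independently nameable entries: `params/w`, `params/b/0`, `params/b/1` and `step`.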
When running the latest `orbax-export`, the `from orbax.export.protos import oex_orchestration_pb2` line in `utils.py` fails with `ImportError: cannot import name 'oex_orchestration_pb2' from 'orbax.export.protos'`. Here is a test log: https://github.com/jax-ml/jax-ai-stack/actions/runs/17402977130/job/49400221219.
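A quick way to check whether an installed package actually ships a generated module such as `oex_orchestration_pb2` is to locate it with the standard library instead of importing it. The sketch below is a diagnostic helper of our own (`module_available` is a hypothetical name, not part of Orbax). It relies only on `importlib.util.find_spec`.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the module `name` can be located.

    find_spec imports the parent packages of a dotted name, so a missing or
    broken parent raises ImportError (including ModuleNotFoundError), which
    is reported here as "not available".
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False
```

Calling `module_available("orbax.export.protos.oex_orchestration_pb2")` in the failing environment returns `False`. This shows whether the generated file is missing from the installed wheel or whether the problem lies elsewhere.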
Internal change.
Internal adjustment to tree-verity reporting
Migrate Tunelab to Orbax `PreservationPolicy` for checkpoint management.
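A preservation policy decides which saved steps survive garbage collection. The plain-Python sketch below only illustrates the idea. `steps_to_preserve`, `keep_latest` and `keep_every` are hypothetical names, not Orbax's `PreservationPolicy` interface. The rule shown keeps the newest N steps plus any step that is a multiple of a fixed interval.

```python
from typing import Iterable, Optional, Set


def steps_to_preserve(
    steps: Iterable[int], keep_latest: int, keep_every: Optional[int] = None
) -> Set[int]:
    """Hypothetical retention rule.

    Keeps the newest `keep_latest` steps, plus every step that is a multiple
    of `keep_every` when it is given. All other steps may be deleted.
    """
    ordered = sorted(set(steps))
    # Guard keep_latest <= 0: ordered[-0:] would otherwise return every step.
    keep = set(ordered[-keep_latest:]) if keep_latest > 0 else set()
    if keep_every:
        keep |= {s for s in ordered if s % keep_every == 0}
    return keep
```

Expressing retention as a union of simple rules keeps each rule easy to test. Combining them is just set union.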
It seems there was an oversight in the case where HNS is enabled but no `step_prefix` is provided. In that case, `step_prefix` is `None`, which results in an error when...
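The report is truncated, so the exact failing call isn't shown. The following is a minimal sketch of the kind of guard involved. It assumes a hypothetical `step_dir_name` helper and a `<prefix>_<step>` directory naming scheme. It is not Orbax's implementation.

```python
from typing import Optional


def step_dir_name(step: int, step_prefix: Optional[str] = None) -> str:
    """Hypothetical step-directory naming that tolerates a missing prefix."""
    if step_prefix is None:
        # No prefix configured: fall back to the bare step number instead of
        # formatting None into the name or calling string methods on it.
        return str(step)
    return f"{step_prefix}_{step}"
```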
`manager.restore` now returns an object from `orbax.checkpoint._src.handlers.composite_checkpoint_handler` instead of a PyTree, so we need to access the `state` member of the returned object.
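Code that must handle both the old PyTree return value and the new composite-style result can unwrap the item in one place. This is a minimal sketch. It assumes the restored item is exposed as a `state` attribute, as described above. `restored_state` is a hypothetical helper, and `SimpleNamespace` stands in for the real return type so the snippet runs without Orbax.

```python
from types import SimpleNamespace
from typing import Any


def restored_state(restored: Any) -> Any:
    """Hypothetical compatibility shim.

    Returns `restored.state` when the restore result exposes a `state`
    attribute (composite-style result), otherwise returns the value as-is
    (older behaviour, where the PyTree was returned directly).
    """
    return getattr(restored, "state", restored)


# Stand-ins for the two return shapes (assumption: the item is named "state").
composite_like = SimpleNamespace(state={"w": 1.0})
plain_pytree = {"w": 1.0}
```

With the real API, the call site simply becomes `state = manager.restore(step).state` once every caller is on the new return type.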
Internal change
Internal change
Update the sidv2 and yt_prod_v3 test data and re-enable the test.