orbax
Orbax provides common utility libraries for JAX users.
Rename the concrete class `Array` to `ArrayImpl`
Use OrbaxCheckpointManager in Pax training loop when orbax flag is enabled.
Add OrbaxCheckpointManager in Pax which aligns with Orbax APIs. A few overrides are needed for full compatibility with Pax checkpoint format and extra features.
Use flag-guarded Orbax APIs to implement save/restore GDA checkpoint.
Internal changes.
Add Checkpointer/CheckpointManager implementations that rely on t5x.checkpoints.Checkpointer to execute save and restore. This refactor should not change behavior. This represents an intermediate stage before creation of more generic APIs not...
See this code sample:

```python
import pathlib
import orbax.checkpoint

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
tree = {'key': 7}
path = pathlib.Path(__file__).parent / 'foo'
orbax_checkpointer.save(path, tree)
```

When I run it on Linux,...
Ensure OCDBT driver does not exceed file limit
Parallelize many directory creations using asyncio. Also reduce the number of `asyncio.run` calls by moving `async` functions higher in the stack.
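The idea in the entry above can be sketched in a few lines using only the Python standard library. This is a minimal illustration, not the Orbax implementation: the helper names `make_dir`, `make_dirs`, and `create_all` are hypothetical, and it assumes Python 3.9+ for `asyncio.to_thread`. Each blocking `mkdir` runs in a worker thread, the coroutines are gathered so the directories are created concurrently, and `asyncio.run` is called once at the top of the stack rather than once per directory.

```python
import asyncio
import pathlib
import tempfile


async def make_dir(path: pathlib.Path) -> pathlib.Path:
    # Path.mkdir blocks, so hand it to a worker thread (hypothetical helper).
    await asyncio.to_thread(path.mkdir, parents=True, exist_ok=True)
    return path


async def make_dirs(paths):
    # Schedule every creation at once; gather preserves input order.
    return await asyncio.gather(*(make_dir(p) for p in paths))


def create_all(root: pathlib.Path, names):
    # Single asyncio.run at the top of the call stack; everything
    # below this point stays async.
    return asyncio.run(make_dirs([root / name for name in names]))


# Demo against a throwaway temporary directory.
root = pathlib.Path(tempfile.mkdtemp())
created = create_all(root, ['step_0', 'step_1/state', 'step_2'])
```

Keeping the `async` functions high in the stack means callers pay the event-loop startup cost once, however many directories are requested.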
Add logging to track deprecated codepaths.