orbax
Orbax provides common utility libraries for JAX users.
internal change
Internal.
Internal change.
Create a `JaxModule` Protocol class so that the interface can be separated from the implementation.
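The general pattern of separating an interface from its implementation with `typing.Protocol` can be sketched as follows. This is not the actual Orbax `JaxModule` definition. The names `JaxModuleInterface`, `LinearImpl`, `apply` and `run` are hypothetical and are used only to illustrate structural typing.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class JaxModuleInterface(Protocol):
    """Hypothetical interface: declares only what callers may rely on."""

    def apply(self, params: Any, x: Any) -> Any:
        ...


class LinearImpl:
    """Hypothetical concrete implementation; it never inherits from the Protocol."""

    def apply(self, params: Any, x: Any) -> Any:
        w, b = params
        return w * x + b


def run(module: JaxModuleInterface, params: Any, x: Any) -> Any:
    # Callers depend only on the Protocol, not on LinearImpl.
    return module.apply(params, x)


m = LinearImpl()
is_module = isinstance(m, JaxModuleInterface)  # structural check: True
result = run(m, (2.0, 1.0), 3.0)  # 2.0 * 3.0 + 1.0 == 7.0
```

Because a Protocol is matched structurally, any class that provides the declared methods satisfies it. Implementations can therefore change or be swapped out without touching code that is written against the interface.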
SAX implementation of gemini-phase2 MMVisionModel.
Add `TypeHandler`s to register default `SaveArg`/`RestoreArg`s. This allows users to save with `Checkpointer.save(..., save_arg=None)` (or restore with `restore_arg=None`) and have the argument automatically converted to the default `SaveArg`/`RestoreArg` for...
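The idea of falling back to a registered default when the caller passes `None` can be sketched with a small per-type registry. This is a minimal illustration under stated assumptions, not Orbax's implementation. `SaveArgSketch`, `register_default` and `resolve_save_arg` are hypothetical names, and the fields and defaults are invented for the example.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Type


@dataclass(frozen=True)
class SaveArgSketch:
    # Hypothetical stand-in for a per-leaf save option (not Orbax's SaveArgs).
    dtype: Optional[str] = None
    compress: bool = False


# Hypothetical registry: value type -> factory producing its default argument.
_DEFAULTS: Dict[Type, Callable[[], SaveArgSketch]] = {}


def register_default(value_type: Type, factory: Callable[[], SaveArgSketch]) -> None:
    # Each (hypothetical) type handler registers how to build its default argument.
    _DEFAULTS[value_type] = factory


def resolve_save_arg(value: Any, save_arg: Optional[SaveArgSketch]) -> SaveArgSketch:
    # An explicit argument wins; None falls back to the registered default.
    if save_arg is not None:
        return save_arg
    for value_type, factory in _DEFAULTS.items():
        if isinstance(value, value_type):
            return factory()
    raise TypeError(f"No default registered for {type(value).__name__}")


register_default(int, lambda: SaveArgSketch(dtype="int64"))
register_default(float, lambda: SaveArgSketch(dtype="float32", compress=True))
```

With this setup, `resolve_save_arg(3, None)` returns the registered `int` default, and an explicitly passed argument is returned unchanged.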
Pass `int` to grain seed
`orbax/checkpoint/pytree_checkpoint_handler.py:661` has the following check: `if not item`. It most likely should be `if item is None`, since otherwise the check raises an error when `item` is an array...
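The difference between the two checks can be reproduced outside Orbax. The sketch below uses NumPy arrays as a stand-in for the arrays in the issue. The helpers `is_missing_truthy` and `is_missing_none` are hypothetical and simply mirror the two conditions being compared.

```python
import numpy as np


def is_missing_truthy(item):
    # Mirrors the `if not item` pattern reported in the issue.
    return not item


def is_missing_none(item):
    # Suggested fix: only treat the item as missing when it is literally None.
    return item is None


arr = np.arange(4)

try:
    is_missing_truthy(arr)
    truthy_raised = False
except ValueError:
    # NumPy refuses to coerce a multi-element array to a single bool.
    truthy_raised = True

# A 0-d zero array does not raise, but `not` silently treats it as missing.
zero_misread = is_missing_truthy(np.zeros(()))  # True

none_check = is_missing_none(arr)  # False: an array is never None
```

The identity check `item is None` never calls `__bool__` on the value, so it behaves the same way for arrays of any shape or content.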
I'm benchmarking loading a 65B sharded transformer model on multiple GPUs on the same host. The checkpoint itself is not sharded, but when the model is being loaded, a correct...
Add a retention policy for `PyTreeCheckpointHandler` metrics.