orbax
Orbax provides common utility libraries for JAX users.
Add user-configured callback to async_checkpointer.
Delegate to BasePyTreeCheckpointHandler rather than inheriting from it.
Incorporate async-compatible, process-subset-compatible barrier function. With this API, barrier names must be completely unique, and many places must be adjusted to take this into account. There is also an issue...
Hi Orbax community, under the hood Orbax uses [TensorStore](https://google.github.io/tensorstore) for tensor IO; the TensorStore integration is part of [type_handlers.py](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/type_handlers.py). TensorStore comes with KVStore implementations for [File](https://google.github.io/tensorstore/kvstore/file/index.html), [GCS](https://google.github.io/tensorstore/kvstore/gcs/index.html), [S3](https://google.github.io/tensorstore/kvstore/s3/index.html), [GRPC](https://google.github.io/tensorstore/kvstore/tsgrpc/index.html), etc....
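To make the KVStore point concrete, here is a minimal sketch of how a checkpoint location might be mapped to a TensorStore-style spec dictionary. The helper names `kvstore_spec` and `array_spec` are hypothetical and are not part of Orbax or TensorStore; only the `driver`, `bucket`, `path`, and `kvstore` keys follow the TensorStore spec format. The sketch builds plain dictionaries and does not import TensorStore.

```python
from typing import Any, Dict


def kvstore_spec(location: str) -> Dict[str, Any]:
    """Map a location string to a TensorStore-style kvstore spec (hypothetical helper)."""
    # Assumption: only gs:// and s3:// URL schemes are recognized here.
    for scheme, driver in (("gs://", "gcs"), ("s3://", "s3")):
        if location.startswith(scheme):
            bucket, _, path = location[len(scheme):].partition("/")
            return {"driver": driver, "bucket": bucket, "path": path}
    # Anything without a recognized scheme is treated as a local file path.
    return {"driver": "file", "path": location}


def array_spec(location: str, array_driver: str = "zarr") -> Dict[str, Any]:
    """Wrap the kvstore spec in an array-level spec (hypothetical helper)."""
    return {"driver": array_driver, "kvstore": kvstore_spec(location)}


if __name__ == "__main__":
    print(array_spec("gs://my-bucket/ckpt/step_10"))
    print(array_spec("/tmp/ckpt/step_10"))
```

A dictionary of this shape is the kind of spec that `tensorstore.open` accepts, although a real call would usually need further fields (for example array metadata) that are omitted here.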
Fork a small amount of Orbax code into Pax to deal with writing "aggregate" files, as Orbax will soon lose this ability.
Replace references to deprecated device_buffer attributes. `jax.Array.device_buffer` and `jax.Array.device_buffers` will be deprecated as of JAX version 0.4.22; see https://github.com/google/jax/pull/18844.
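As a rough illustration of this kind of migration, the sketch below reads per-device buffers through `addressable_data(0)` and `addressable_shards` instead of the deprecated attributes. That these are the intended replacements is an assumption based on the JAX deprecation notice and should be checked against the linked pull request. The helper names and the `FakeArray` stand-in are hypothetical; the stand-in only mimics the two attributes used, so the example runs without JAX installed.

```python
from types import SimpleNamespace


def first_addressable_buffer(arr):
    """Assumed replacement for the deprecated `arr.device_buffer`."""
    return arr.addressable_data(0)


def all_addressable_buffers(arr):
    """Assumed replacement for the deprecated `arr.device_buffers`."""
    return [shard.data for shard in arr.addressable_shards]


class FakeArray:
    """Hypothetical stand-in exposing only the attributes used above."""

    def __init__(self, buffers):
        self.addressable_shards = [SimpleNamespace(data=b) for b in buffers]

    def addressable_data(self, index):
        return self.addressable_shards[index].data


if __name__ == "__main__":
    arr = FakeArray(["buffer-on-device-0", "buffer-on-device-1"])
    print(first_addressable_buffer(arr))
    print(all_addressable_buffers(arr))
```

With a real `jax.Array`, the same two helpers should apply unchanged, since they rely only on `addressable_data` and `addressable_shards`.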
Hello, I am trying to load metadata on a new device from a checkpoint via the `CheckpointManager` API, but I am struggling to find a solution. Below is a minimal example of...
Add try/except block to catch ValueError when reading from sharding string.
Add OCDBT support to GlobalHostArray type_handler
I'm trying to checkpoint Flax's TrainState in a distributed setup, where each node has access to multiple devices:
```
def save_checkpoint(args, state, step):
    state = unreplicate(state)  # flax.jax_utils.unreplicate...
```