orbax

Orbax provides common utility libraries for JAX users.

Results 335 orbax issues

Add user-configured callback to async_checkpointer.

Delegate to BasePyTreeCheckpointHandler rather than inheriting from it.

Hi Orbax community. Under the hood, Orbax uses [TensorStore](https://google.github.io/tensorstore) for tensor IO; the TensorStore integration is part of [type_handlers.py](https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/type_handlers.py). TensorStore comes with KVStore implementations for [File](https://google.github.io/tensorstore/kvstore/file/index.html), [GCS](https://google.github.io/tensorstore/kvstore/gcs/index.html), [S3](https://google.github.io/tensorstore/kvstore/s3/index.html), [gRPC](https://google.github.io/tensorstore/kvstore/tsgrpc/index.html), etc....

type:feature
checkpoint

Fork a small amount of Orbax code into Pax to deal with writing "aggregate" files, as Orbax will soon lose this ability.

Replace references to deprecated device_buffer attributes. `jax.Array.device_buffer` and `jax.Array.device_buffers` will be deprecated as of jax version 0.4.22; see https://github.com/google/jax/pull/18844.
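For readers updating their own code, the replacement can be sketched as follows. This is a minimal illustration rather than the change made in the referenced PR; it uses `addressable_data` and `addressable_shards`, the sharding-based accessors on `jax.Array`, and the array `arr` is just an example value.

```python
import jax.numpy as jnp

# Example array; on a single-device host it has exactly one addressable shard.
arr = jnp.arange(8.0)

# Previously `arr.device_buffer`: the data held on the first addressable device.
first_buffer = arr.addressable_data(0)

# Previously `arr.device_buffers`: one single-device array per addressable shard.
buffers = [shard.data for shard in arr.addressable_shards]
```

Both accessors return single-device `jax.Array` values, so downstream code that only reads shapes or contents should need no further changes.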

Hello, I am trying to load metadata on a new device from a checkpoint via the `CheckpointManager` API, but I am struggling to find a solution. Below is a minimal example of...

Add try/except block to catch ValueError when reading from sharding string.
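The general pattern behind a change like this can be sketched in plain Python. The helper name `parse_sharding_string` and the assumption that the sharding string is JSON-encoded are hypothetical and chosen for illustration; this is not Orbax's actual code.

```python
import json
import logging
from typing import Optional


def parse_sharding_string(raw: str) -> Optional[dict]:
    """Hypothetical helper: decode a JSON-encoded sharding description.

    Returns None instead of raising when the string cannot be parsed.
    """
    try:
        parsed = json.loads(raw)
    except ValueError as e:  # json.JSONDecodeError is a subclass of ValueError
        logging.warning("Could not parse sharding string %r: %s", raw, e)
        return None
    # Treat well-formed JSON that is not an object as unusable as well.
    if not isinstance(parsed, dict):
        return None
    return parsed
```

Returning `None` lets the caller fall back to a default sharding rather than aborting the whole restore, which is one plausible reason to catch the error at this point.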

I'm trying to checkpoint Flax's TrainState in a distributed setup, where each node has access to multiple devices:

```
def save_checkpoint(args, state, step):
    state = unreplicate(state)  # flax.jax_utils.unreplicate
    ...
```