orbax
Orbax provides common utility libraries for JAX users.
A user posted in the Flax discussions about an Orbax discrepancy between different GCE zones. Do different zones have different Orbax versions? What happened: When I save...
Fix errors when loading checkpoints that contain sharding configs.
The Orbax change comes from the diffbase.
How can I restore a variable from a checkpoint saved on CPU back onto CPU when both GPU and CPU are available?
I get the following error when I try to restore: **ValueError: SingleDeviceSharding with Device=TFRT_CPU_0 was not found in jax.local_devices().** Despite enclosing the statements within a CPU device scope, like below,...
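Here is one plausible reading of this error. `jax.local_devices()` lists the devices of the default backend, which is the GPU when one is present, so a sharding recorded against `TFRT_CPU_0` may not be matched. The sketch below shows a workaround under one assumption: that the checkpoint can first be restored as host (NumPy) arrays. Each leaf is then placed on an explicit CPU `SingleDeviceSharding`. The `restored` dict is hypothetical stand-in data, not real Orbax output, and the snippet uses only core JAX calls.

```python
import jax
import numpy as np
from jax.sharding import SingleDeviceSharding

# Explicit CPU target, independent of which backend is the default.
cpu_device = jax.devices("cpu")[0]
cpu_sharding = SingleDeviceSharding(cpu_device)


def place_on_cpu(tree):
    """Put every leaf of a pytree onto the CPU device."""
    return jax.tree_util.tree_map(lambda x: jax.device_put(x, cpu_sharding), tree)


# Hypothetical host arrays standing in for a restored checkpoint.
restored = {"w": np.arange(6.0).reshape(2, 3), "step": np.array(10)}
on_cpu = place_on_cpu(restored)

for name, leaf in on_cpu.items():
    print(name, leaf.shape, leaf.devices())
```

The `cpu_sharding` object is the kind of per-leaf target sharding a restore call would need. How to pass it to Orbax depends on the installed version, so check that version's documentation rather than relying on this sketch.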
Hello, this is somewhat similar to #646. During training, I saved my parameters in a sharded manner (I could not use aggregate because they were sharded over multiple hosts). Now I...
Hi, I made checkpoints of my weights and some JAX arrays using the following snippet on a GPU instance (V100): ``` import orbax.checkpoint as ocp checkpointer = ocp.PyTreeCheckpointer() checkpointer.save(path, checkpoint)...
Hi, we are trying out the orbax (0.4.1) AsyncCheckpointer (used through CheckpointManager). We are getting "Array has been deleted" errors. It seems as if the async checkpointer is trying to...
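An "Array has been deleted" error can arise in any asynchronous save in one generic way: the device buffers handed to the writer are freed (for example, donated to a jitted training step) before the background thread reads them. The sketch below does not use Orbax and makes no claim about how `AsyncCheckpointer` works internally. `background_write` and `prepare_write` are hypothetical helpers, and `Array.delete()` stands in for buffer donation. It contrasts a writer that holds the original arrays with one that holds a defensive copy.

```python
import threading

import jax
import jax.numpy as jnp
import numpy as np


def background_write(tree, out):
    """Hypothetical async writer: reads the arrays only when the thread runs."""
    try:
        out["saved"] = jax.tree_util.tree_map(np.asarray, tree)
    except Exception as err:  # reading a deleted jax.Array raises
        out["error"] = str(err)


def prepare_write(tree, snapshot):
    """Create (but do not start) a writer thread, optionally on a copy."""
    out = {}
    if snapshot:
        # Defensive device-side copy that the training loop never frees.
        tree = jax.tree_util.tree_map(jnp.copy, tree)
        jax.tree_util.tree_map(lambda a: a.block_until_ready(), tree)
    thread = threading.Thread(target=background_write, args=(tree, out))
    return thread, out


params = {"w": jnp.arange(4.0), "b": jnp.ones(2)}
unsafe_thread, unsafe_out = prepare_write(params, snapshot=False)
safe_thread, safe_out = prepare_write(params, snapshot=True)

# Simulate the training loop freeing (e.g. donating) the original buffers
# before the background writers get a chance to read them.
for leaf in jax.tree_util.tree_leaves(params):
    leaf.delete()

for thread in (unsafe_thread, safe_thread):
    thread.start()
    thread.join()

print("unsafe write failed:", "error" in unsafe_out)
print("safe write saved w =", safe_out["saved"]["w"])
```

The copy trades extra device memory for safety. Copying to host memory before the buffers are released is the other obvious option.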
Provide better support for custom `CheckpointHandler`s without registered `CheckpointArgs` by providing a wrapper `CheckpointHandler` as a fallback. This class is introduced for backwards compatibility, and will eventually be removed.
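The fallback described here is an adapter: a handler with a registered args type is used directly, and any other handler is wrapped so that it can still be driven through an args object. The sketch below illustrates that pattern only. `ARGS_REGISTRY`, `FallbackArgs`, `LegacyHandler`, `CompatibilityWrapper`, and `resolve_handler` are all hypothetical names, not Orbax classes. "Saving" writes into an in-memory dict, so the paths are just keys.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

# Hypothetical registry: handler class -> the args type it was registered with.
ARGS_REGISTRY: Dict[type, type] = {}


@dataclass
class FallbackArgs:
    """Hypothetical generic args: carries legacy keyword arguments as-is."""
    kwargs: Dict[str, Any] = field(default_factory=dict)


class LegacyHandler:
    """Hypothetical custom handler with no registered args type."""

    def __init__(self):
        self.storage: Dict[str, Any] = {}

    def save(self, directory: str, item: Any) -> None:
        self.storage[directory] = item

    def restore(self, directory: str) -> Any:
        return self.storage[directory]


class CompatibilityWrapper:
    """Adapts an args-based call to a legacy handler's keyword interface."""

    def __init__(self, handler: Any):
        self.handler = handler

    def save(self, directory: str, args: FallbackArgs) -> None:
        self.handler.save(directory, **args.kwargs)

    def restore(self, directory: str, args: FallbackArgs) -> Any:
        return self.handler.restore(directory, **args.kwargs)


def resolve_handler(handler: Any) -> Any:
    """Use the handler directly if its args type is registered, else wrap it."""
    if type(handler) in ARGS_REGISTRY:
        return handler
    return CompatibilityWrapper(handler)


handler = resolve_handler(LegacyHandler())
handler.save("/tmp/ckpt/0", FallbackArgs({"item": {"step": 1}}))
print(type(handler).__name__, handler.restore("/tmp/ckpt/0", FallbackArgs()))
```

Keeping the wrapper separate from the registry lookup means it can be deleted later without touching handlers that already register their args.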
Internal.
Internal changes.