orbax
Orbax provides common utility libraries for JAX users.
Hi, I am receiving this warning `WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([],...
Reproduction code:

```python
import jax

class PyTreeDict(dict):
    pass

jax.tree_util.register_pytree_node(
    PyTreeDict,
    lambda d: (tuple(d.values()), tuple(d.keys())),  # flatten -> (children, aux_data)
    lambda keys, values: PyTreeDict(dict(zip(keys, values)))  # unflatten(aux_data, children)
)

a = {"a": PyTreeDict()}  # ValueError: Expected dict, got {}.
# a = PyTreeDict()  # ValueError:...
```
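The registration in the reproduction follows the `jax.tree_util.register_pytree_node(nodetype, flatten_func, unflatten_func)` contract. The flatten function returns `(children, aux_data)`, and JAX later calls the unflatten function as `unflatten_func(aux_data, children)`, so keys-then-values is the correct argument order. The minimal round trip below uses only `jax.tree_util` to check that an empty `PyTreeDict` survives flattening and unflattening in JAX itself. It does not reproduce or explain the Orbax error. The helper names `_flatten` and `_unflatten` are illustrative.

```python
import jax


class PyTreeDict(dict):
    """A dict subclass registered as its own pytree node type."""


def _flatten(d):
    # Children are the values; aux_data is the keys, in matching order.
    return tuple(d.values()), tuple(d.keys())


def _unflatten(keys, values):
    # JAX calls unflatten_fn(aux_data, children), so the keys come first.
    return PyTreeDict(zip(keys, values))


jax.tree_util.register_pytree_node(PyTreeDict, _flatten, _unflatten)

# An empty PyTreeDict contributes no leaves but keeps its node in the treedef.
tree = {"a": PyTreeDict(x=1, y=2), "b": PyTreeDict()}
leaves, treedef = jax.tree_util.tree_flatten(tree)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```

Suppose this round trip succeeds but saving the same tree with Orbax still fails. That points toward how the checkpointing path handles the empty node, rather than toward the registration.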
Remove `all_params_aggregated`.
Modify `JsonCheckpointHandler` to be async.
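The change title does not describe how `JsonCheckpointHandler` is made async. As a generic illustration of the technique, and not Orbax's implementation, the sketch below serializes the object on the calling thread and performs the file write on a background worker. The caller gets back a `concurrent.futures.Future` to wait on. `save_json_async` is a hypothetical helper name.

```python
import json
import tempfile
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path

# A single worker keeps writes ordered.
_executor = ThreadPoolExecutor(max_workers=1)


def save_json_async(path, obj) -> Future:
    """Hypothetical helper: serialize now, write to disk in the background."""
    # Serialize synchronously so later mutations of `obj` cannot leak into the file.
    payload = json.dumps(obj)
    target = Path(path)
    tmp = target.with_name(target.name + ".tmp")

    def _write() -> Path:
        tmp.write_text(payload)
        # Rename into place so readers never see a half-written file.
        tmp.replace(target)
        return target

    return _executor.submit(_write)


# Usage: the caller continues immediately and waits on the future only when needed.
state = {"step": 1, "lr": 0.1}
out = Path(tempfile.mkdtemp()) / "meta.json"
future = save_json_async(out, state)
state["step"] = 2  # does not affect the payload already being written
future.result()  # block until the write has finished
```

Snapshotting first and writing later is the core trade-off in async checkpointing. The caller pays only for serialization, and disk I/O overlaps with whatever work comes next.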
Add an option to allow sharding consolidation / reduction as a workaround for the DTensor limitation that disallows sharding one dimension across multiple axis names. This option is disabled by...
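In a JAX-style partition spec, a single array dimension can be sharded across several mesh axes, for example `("x", "y")`. The number of shards along that dimension is then the product of those axis sizes. The sketch below is a hypothetical illustration of one way a "consolidation" step could rewrite such a spec so each dimension names at most one axis while keeping the shard count. It is not the option this change adds, whose exact behavior the description does not spell out. Specs are plain Python lists here so the sketch runs without devices. `shards_along_dims`, `consolidate_spec`, and the `"x*y"` naming are assumptions.

```python
from math import prod
from typing import Dict, List, Mapping, Sequence, Tuple, Union

# One entry per array dimension, mirroring a JAX PartitionSpec entry:
# None (replicated), a single mesh axis name, or a tuple of mesh axis names.
SpecEntry = Union[None, str, Tuple[str, ...]]


def shards_along_dims(spec: Sequence[SpecEntry], mesh: Mapping[str, int]) -> List[int]:
    """Number of shards along each array dimension for the given mesh sizes."""
    counts = []
    for entry in spec:
        if entry is None:
            counts.append(1)
        elif isinstance(entry, str):
            counts.append(mesh[entry])
        else:
            counts.append(prod(mesh[name] for name in entry))
    return counts


def consolidate_spec(spec: Sequence[SpecEntry], mesh: Mapping[str, int]):
    """Hypothetical: merge multi-axis entries into one synthetic axis name."""
    new_spec: List[SpecEntry] = []
    new_mesh: Dict[str, int] = dict(mesh)
    for entry in spec:
        if isinstance(entry, tuple):
            if len(entry) == 0:
                new_spec.append(None)
            elif len(entry) == 1:
                new_spec.append(entry[0])
            else:
                merged = "*".join(entry)
                # The synthetic axis gets the product size, preserving the shard count.
                new_mesh[merged] = prod(mesh[name] for name in entry)
                new_spec.append(merged)
        else:
            new_spec.append(entry)
    return new_spec, new_mesh


# Usage: dimension 0 is sharded over both mesh axes, dimension 1 is replicated.
mesh = {"x": 2, "y": 4}
spec = [("x", "y"), None]
new_spec, new_mesh = consolidate_spec(spec, mesh)
```

The sketch only does the bookkeeping. Whether a runtime mesh can actually be reshaped to expose such a merged axis depends on how the other dimensions use `x` and `y`.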
Create the `load_model_with_tfrt` utility API for XLA devices.
Are there any plans to support this in Orbax? TensorStore can interpret strings (https://google.github.io/tensorstore/python/api/tensorstore.string.html). I realize I can pull out the object arrays into a JSON file and then stitch...
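The workaround mentioned in the issue moves object arrays into a JSON file and stitches them back after loading. It could look roughly like the sketch below, which assumes a flat `dict` whose string data lives in NumPy `dtype=object` arrays. `split_object_arrays` and `stitch_object_arrays` are hypothetical names, not Orbax APIs. The numeric part would still be checkpointed with Orbax as usual.

```python
import json

import numpy as np


def split_object_arrays(tree: dict):
    """Separate object-dtype arrays (e.g. strings) from everything else."""
    numeric, objects = {}, {}
    for key, value in tree.items():
        if isinstance(value, np.ndarray) and value.dtype == object:
            # Store shape plus flattened values so the array can be rebuilt.
            objects[key] = {"shape": list(value.shape), "values": value.ravel().tolist()}
        else:
            numeric[key] = value
    return numeric, objects


def stitch_object_arrays(numeric: dict, objects: dict) -> dict:
    """Rebuild object arrays from their JSON payloads and merge them back in."""
    out = dict(numeric)
    for key, payload in objects.items():
        # Fill a 1-D object array element-wise to avoid NumPy re-nesting lists.
        flat = np.empty(len(payload["values"]), dtype=object)
        flat[:] = payload["values"]
        out[key] = flat.reshape(payload["shape"])
    return out


# Usage: round-trip the object arrays through a JSON string.
tree = {
    "w": np.arange(3.0),
    "labels": np.array([["a", "b"], ["c", "d"]], dtype=object),
}
numeric, objects = split_object_arrays(tree)
text = json.dumps(objects)
restored = stitch_object_arrays(numeric, json.loads(text))
```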
Internal change.
Internal change.
Enable Orbax checkpointing with a background deletion thread.
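The change title does not say how the deletion thread is wired into Orbax. As a generic illustration of the technique, the sketch below queues old checkpoint directories and removes them on a worker thread, so a save loop does not block on `shutil.rmtree`. `BackgroundDeleter` is a hypothetical helper, not an Orbax class.

```python
import queue
import shutil
import tempfile
import threading
from pathlib import Path


class BackgroundDeleter:
    """Hypothetical helper: deletes old checkpoint directories off the main thread."""

    def __init__(self):
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        while True:
            path = self._queue.get()
            try:
                if path is None:  # sentinel: stop the worker
                    return
                shutil.rmtree(path, ignore_errors=True)
            finally:
                self._queue.task_done()

    def delete(self, path):
        """Schedule `path` for deletion and return immediately."""
        self._queue.put(Path(path))

    def close(self):
        """Finish pending deletions (FIFO order), then stop the worker thread."""
        self._queue.put(None)
        self._thread.join()


# Usage: keep only the newest of three fake checkpoint directories.
root = Path(tempfile.mkdtemp())
for step in range(3):
    (root / f"step_{step}").mkdir()
    (root / f"step_{step}" / "data").write_text("x")

deleter = BackgroundDeleter()
deleter.delete(root / "step_0")
deleter.delete(root / "step_1")
deleter.close()
remaining = sorted(p.name for p in root.iterdir())
```

Because the queue is FIFO and the stop sentinel is enqueued last, `close()` returns only after every previously scheduled deletion has run.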