orbax

Orbax provides common utility libraries for JAX users.

Results 335 orbax issues

Hi, I am receiving this warning `WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before May 1st, 2024. If your Pytree has empty ([],...

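Reproduction code:

```python
import jax


class PyTreeDict(dict):
    pass


# Register PyTreeDict as a pytree node: the values are the children and the
# keys are the auxiliary data used to rebuild the container.
jax.tree_util.register_pytree_node(
    PyTreeDict,
    lambda d: (tuple(d.values()), tuple(d.keys())),
    lambda keys, values: PyTreeDict(dict(zip(keys, values))),
)

a = {"a": PyTreeDict()}  # ValueError: Expected dict, got {}.
# a = PyTreeDict()  # ValueError:...
```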

Add an option to allow sharding consolidation / reduction, to work around the DTensor limitation that a single dimension cannot be sharded across multiple axis names. This option is disabled by...

Create the `load_model_with_tfrt` utility function API for XLA devices.

Are there any plans to support this in orbax? Tensorstore can interpret strings: https://google.github.io/tensorstore/python/api/tensorstore.string.html
I realize I can pull out the object arrays into a json file and then stitch...
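The workaround mentioned above (pulling string data out into a JSON file and stitching it back in on restore) could look roughly like the sketch below. It is only an illustration under several assumptions: the tree is a nested `dict` with string keys that contain no `/`, string leaves are plain `str` values or non-empty lists of `str` (standing in for object arrays), and the helper names `split_string_leaves` and `stitch_string_leaves` are hypothetical, not part of Orbax or TensorStore.

```python
import copy
import json
import os
import tempfile


def split_string_leaves(tree, path=()):
    """Split a nested dict into (rest, strings).

    `strings` maps "/"-joined key paths to the string leaves removed from
    the tree; `rest` keeps everything else. Hypothetical helper.
    """
    rest, strings = {}, {}
    for key, value in tree.items():
        here = path + (key,)
        is_str_list = (
            isinstance(value, list)
            and len(value) > 0
            and all(isinstance(v, str) for v in value)
        )
        if isinstance(value, dict):
            sub_rest, sub_strings = split_string_leaves(value, here)
            rest[key] = sub_rest
            strings.update(sub_strings)
        elif isinstance(value, str) or is_str_list:
            strings["/".join(here)] = value
        else:
            rest[key] = value
    return rest, strings


def stitch_string_leaves(rest, strings):
    """Re-insert string leaves at their original key paths. Hypothetical helper."""
    merged = copy.deepcopy(rest)
    for joined, value in strings.items():
        *parents, leaf = joined.split("/")
        node = merged
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return merged


tree = {
    "params": {"w": [1.0, 2.0]},
    "meta": {"labels": ["cat", "dog"], "name": "run-1"},
}
rest, strings = split_string_leaves(tree)

with tempfile.TemporaryDirectory() as tmp:
    sidecar = os.path.join(tmp, "strings.json")
    with open(sidecar, "w") as f:
        json.dump(strings, f)
    # `rest` is what would be handed to the regular checkpointer here.
    with open(sidecar) as f:
        restored = stitch_string_leaves(rest, json.load(f))

assert restored == tree
```

The sidecar file keeps the numeric part of the tree free of string leaves, at the cost of having to keep two artifacts in sync.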

Enable Orbax checkpointing with background deleting thread
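As a rough illustration of the idea in this title, the sketch below deletes old checkpoint directories on a worker thread so that the saving loop does not block on file removal. It uses only the Python standard library and is not Orbax's implementation; the `BackgroundDeleter` class and the `max_to_keep` variable are assumptions made for this example.

```python
import queue
import shutil
import tempfile
import threading
from pathlib import Path


class BackgroundDeleter:
    """Deletes directories on a worker thread (hypothetical helper)."""

    def __init__(self):
        self._queue = queue.Queue()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        while True:
            path = self._queue.get()
            try:
                if path is None:  # sentinel: stop the worker
                    return
                shutil.rmtree(path, ignore_errors=True)
            finally:
                self._queue.task_done()

    def delete(self, path):
        """Schedule `path` for deletion and return immediately."""
        self._queue.put(path)

    def close(self):
        """Finish all pending deletions, then stop the worker."""
        self._queue.put(None)
        self._thread.join()


root = Path(tempfile.mkdtemp())
deleter = BackgroundDeleter()
max_to_keep = 2  # assumed retention policy for this sketch
kept = []

for step in range(5):
    ckpt = root / f"step_{step}"
    ckpt.mkdir()  # stands in for writing a checkpoint
    kept.append(ckpt)
    while len(kept) > max_to_keep:
        deleter.delete(kept.pop(0))  # returns without waiting for the delete

deleter.close()
remaining = sorted(p.name for p in root.iterdir())
shutil.rmtree(root)
```

Because the queue is first-in, first-out, every deletion scheduled before `close()` completes before the worker exits.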