orbax
Incompatibility with Haiku
Reopening an issue about incompatibility with Haiku's module naming conventions (similar to a previous issue). This does not happen in v0.3.5.
Sample code:
from jax import numpy as jnp
import orbax.checkpoint as ocp
import haiku as hk

@hk.transform
def forward_fn(inputs):
  # net = hk.Linear(output_size=2)  # This works
  net = hk.nets.MLP(
      output_sizes=[2, 2], activate_final=True)  # This doesn't work
  return net(inputs)

prng_seq = hk.PRNGSequence(0)
params = forward_fn.init(next(prng_seq), jnp.ones((1, 5)))

ckpt_dir = '/tmp/my-checkpoints/'
orbax_mngr = ocp.CheckpointManager(
    ckpt_dir,
    {'state': ocp.PyTreeCheckpointer()},
    options=ocp.CheckpointManagerOptions(max_to_keep=1),
)
orbax_mngr.save(step=0, items={'state': params})
The error:
Traceback (most recent call last):
  File "/workspaces/modularbayes/examples/bar.py", line 23, in <module>
    orbax_mngr.save(step=0, items={'state': params})
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 496, in save
    self._checkpointers[k].save(item_dir, item, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 79, in save
    self._handler.save(tmpdir, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 818, in save
    asyncio.run(async_save(directory, item, *args, **kwargs))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 811, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 786, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 893, in serialize
    open_future = ts.open(
ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/mlp/~/linear_0.b" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
sys:1: RuntimeWarning: coroutine 'async_serialize' was never awaited
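The ValueError is raised by TensorStore while parsing a JSON Pointer (RFC 6901). In a JSON Pointer, "~" is an escape character: a literal "~" inside a reference token must be written as "~0" and a literal "/" as "~1". Haiku names nested modules with a "/~/" separator (e.g. "mlp/~/linear_0"), so a pointer such as "/mlp/~/linear_0.b" contains a "~" followed by "/", which is not a valid escape. The sketch below is plain Python illustrating the RFC 6901 rules, not Orbax or TensorStore code; the helper names are hypothetical, and treating the whole Haiku key as one reference token is only an assumption made for illustration.

```python
import re

def escape_token(token: str) -> str:
    """Escape one reference token per RFC 6901: '~' -> '~0' first, then '/' -> '~1'."""
    return token.replace("~", "~0").replace("/", "~1")

def unescape_token(token: str) -> str:
    """Inverse of escape_token: '~1' -> '/' first, then '~0' -> '~' (RFC 6901 order)."""
    return token.replace("~1", "/").replace("~0", "~")

# A '~' is only valid when immediately followed by '0' or '1'.
_BAD_TILDE = re.compile(r"~(?![01])")

def is_valid_pointer(pointer: str) -> bool:
    """True if pointer is empty, or starts with '/' and contains no invalid '~' escapes."""
    if pointer == "":
        return True
    return pointer.startswith("/") and _BAD_TILDE.search(pointer) is None

def to_pointer(tokens) -> str:
    """Build a JSON Pointer from raw (unescaped) reference tokens."""
    return "".join("/" + escape_token(t) for t in tokens)

if __name__ == "__main__":
    bad = "/mlp/~/linear_0.b"  # pointer from the traceback above
    print(is_valid_pointer(bad))  # False: '~' is followed by '/'
    good = to_pointer(["mlp/~/linear_0.b"])
    print(good)  # /mlp~1~0~1linear_0.b
    print(is_valid_pointer(good))  # True
```

Escaping "~" before "/" matters: doing it in the other order would turn the "~" introduced by "~1" into "~01".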
Hi Chris,
Thanks for raising the issue. We have submitted a fix, so this should now be resolved. Could you please verify that it no longer errors out for you? Thanks!
Best, Yaning
@liangyaning33
I installed with:
pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'
but I still get the error:
ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/state.disc.batch_norm/~/mean_ema.average" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
Version 0.3.5 works.
Hi Chris,
Would you mind checking which version of orbax-checkpoint you are using? I just checked this in a notebook with v0.4.1 and it works.
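One way to answer this from inside the failing environment is to read the installed distribution versions with the standard library's importlib.metadata (Python 3.8+). This is a generic sketch; the helper name is hypothetical, and the distribution names are assumptions based on the usual pip package names ("orbax-checkpoint" as used with pip in this thread, "dm-haiku" for Haiku).

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(dist_name: str):
    """Return the installed version string of a distribution, or None if it is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None

if __name__ == "__main__":
    # Print the versions of the packages involved in this report.
    for dist in ("orbax-checkpoint", "dm-haiku", "jax", "tensorstore"):
        print(f"{dist}: {installed_version(dist) or 'not installed'}")
```

Checking the distribution metadata rather than a module attribute also distinguishes a git install from a PyPI release only if their version strings differ, so noting how the package was installed is still useful.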
Hi @liangyaning33,
Apologies for the delayed response. I observe the same behaviour as @Carbon225: the error still appears in v0.4.1 (as currently installed by pip install -U orbax-checkpoint).
It works with v0.3.5.
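While the error persists in v0.4.1, one possible workaround (an assumption, not verified against Orbax here) is to rename the Haiku parameter keys so they contain no "~" before saving, and to reverse the renaming after restoring. Haiku params are a nested mapping of module name to parameter name to array, so only the keys need rewriting. The placeholder string and helper names below are hypothetical; the placeholder must not already occur in any module or parameter name, and whether the installed Orbax version handles the remaining "/" in keys correctly is not checked here.

```python
from typing import Any, Mapping

# Hypothetical placeholder; it must not already appear in any module/param name.
TILDE_PLACEHOLDER = "_TILDE_"

def _rename_keys(tree: Any, old: str, new: str) -> Any:
    """Recursively copy nested mappings, replacing `old` with `new` in every string key."""
    if isinstance(tree, Mapping):
        return {k.replace(old, new): _rename_keys(v, old, new) for k, v in tree.items()}
    return tree  # leaves (e.g. arrays) are returned unchanged

def encode_haiku_keys(params):
    """Replace '~' in keys before passing params to a checkpointer."""
    return _rename_keys(params, "~", TILDE_PLACEHOLDER)

def decode_haiku_keys(params):
    """Undo encode_haiku_keys on a restored tree."""
    return _rename_keys(params, TILDE_PLACEHOLDER, "~")

if __name__ == "__main__":
    # Stand-in for Haiku params; real leaves would be jnp arrays.
    params = {"mlp/~/linear_0": {"w": 1, "b": 2}}
    encoded = encode_haiku_keys(params)
    print(encoded)  # {'mlp/_TILDE_/linear_0': {'w': 1, 'b': 2}}
    print(decode_haiku_keys(encoded) == params)  # True
```

In the sample above this would mean saving items={'state': encode_haiku_keys(params)} and applying decode_haiku_keys to the restored 'state' item. The helpers return plain dicts, so code that relies on a specific mapping type should convert back as needed.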