
Incompatibility with Haiku

Open chriscarmona opened this issue 1 year ago • 4 comments

Reopening an issue regarding incompatibility with Haiku naming conventions (similar to a previous issue). The problem does not occur in v0.3.5.

Sample code:

from jax import numpy as jnp
import orbax.checkpoint as ocp
import haiku as hk


@hk.transform
def forward_fn(inputs):
  # net = hk.Linear(output_size=2) # This works
  net = hk.nets.MLP(
      output_sizes=[2, 2], activate_final=True)  # This doesn't work
  return net(inputs)


prng_seq = hk.PRNGSequence(0)
params = forward_fn.init(next(prng_seq), jnp.ones((1, 5)))

ckpt_dir = '/tmp/my-checkpoints/'
orbax_mngr = ocp.CheckpointManager(
    ckpt_dir,
    {'state': ocp.PyTreeCheckpointer()},
    options=ocp.CheckpointManagerOptions(max_to_keep=1),
)
orbax_mngr.save(step=0, items={'state': params})

The error:

Traceback (most recent call last):
  File "/workspaces/modularbayes/examples/bar.py", line 23, in <module>
    orbax_mngr.save(step=0, items={'state': params})
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 496, in save
    self._checkpointers[k].save(item_dir, item, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 79, in save
    self._handler.save(tmpdir, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 818, in save
    asyncio.run(async_save(directory, item, *args, **kwargs))
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 811, in async_save
    commit_futures = await self.async_save(*args, **kwargs)  # pytype: disable=bad-return-type
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 786, in async_save
    commit_futures = await asyncio.gather(*serialize_ops)
  File "/home/ubuntu/.local/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 893, in serialize
    open_future = ts.open(
ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/mlp/~/linear_0.b" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']
sys:1: RuntimeWarning: coroutine 'async_serialize' was never awaited
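
For context, the message refers to the JSON Pointer syntax (RFC 6901), in which `~` is an escape character: a literal `~` must be written as `~0` and a literal `/` as `~1`. Haiku module names such as `mlp/~/linear_0` contain a bare `~`, which is what the parser rejects. The snippet below is only a minimal sketch of those escaping rules in plain Python. The helper names are hypothetical, and nothing here is taken from the orbax or TensorStore code.

```python
import re


def escape_json_pointer_token(token: str) -> str:
    """Escape one reference token following RFC 6901 (hypothetical helper)."""
    # Escape '~' first, so the '~' introduced by '~1' is not escaped again.
    return token.replace("~", "~0").replace("/", "~1")


def unescape_json_pointer_token(token: str) -> str:
    """Reverse the escaping: decode '~1' first, then '~0'."""
    return token.replace("~1", "/").replace("~0", "~")


def has_valid_tilde_escapes(pointer: str) -> bool:
    """Return True if every '~' is followed by '0' or '1'."""
    return re.search(r"~(?![01])", pointer) is None


# The pointer from the traceback contains a bare '~', so the check fails.
print(has_valid_tilde_escapes("/mlp/~/linear_0.b"))

# Escaping the Haiku module name and decoding it again is lossless.
escaped = escape_json_pointer_token("mlp/~/linear_0")
print(escaped)
print(unescape_json_pointer_token(escaped))
```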

chriscarmona avatar Oct 02 '23 15:10 chriscarmona

Hi Chris,

Thanks for raising the issue. We have submitted a fix, and this issue should now be resolved. Can you please verify that it no longer errors out for you? Thanks!!

Best, Yaning

liangyaning33 avatar Oct 10 '23 16:10 liangyaning33

@liangyaning33

I installed with

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

but I get the error

ValueError: Error parsing object member "json_pointer": JSON Pointer requires '~' to be followed by '0' or '1': "/state.disc.batch_norm/~/mean_ema.average" [source locations='tensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']

Version 0.3.5 works.

Carbon225 avatar Oct 20 '23 17:10 Carbon225

Hi Chris,

Do you mind checking which version of orbax-checkpoint you are using? I just checked this in a notebook with v0.4.1 and it works (screenshot attached).
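
One way to confirm which version is actually installed in the active environment is to query the package metadata. This is a minimal sketch using only the standard library. The function name `installed_version` is hypothetical, and the distribution name `orbax-checkpoint` is assumed from the pip command used elsewhere in this thread.

```python
from importlib import metadata


def installed_version(dist_name: str = "orbax-checkpoint"):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Prints the version seen by this interpreter, or None if not installed.
print(installed_version())
```

Running this in the same interpreter or notebook kernel that raises the error helps rule out a mismatch between environments.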

liangyaning33 avatar Oct 24 '23 17:10 liangyaning33

Hi @liangyaning33,

Apologies for the delayed response. I observe the same behaviour as @Carbon225: the error still appears in v0.4.1 (as currently installed by pip install -U orbax-checkpoint). It works with v0.3.5.
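
A possible stopgap, other than pinning v0.3.5, is to rename the parameter keys so that they no longer contain `~` before saving, and to restore the original names after loading. The sketch below only shows the reversible renaming on a plain nested dict shaped like Haiku params. The placeholder string and helper names are hypothetical, and whether this actually avoids the error in a given orbax version is an assumption that has not been verified here.

```python
from collections.abc import Mapping

# Hypothetical placeholder; it must not already occur in any key.
PLACEHOLDER = "__tilde__"


def rename_keys(tree, old: str, new: str):
    """Recursively replace `old` with `new` in every string key of a nested mapping."""
    if isinstance(tree, Mapping):
        return {
            (key.replace(old, new) if isinstance(key, str) else key):
                rename_keys(value, old, new)
            for key, value in tree.items()
        }
    return tree  # Leaves (arrays, scalars, lists) are returned unchanged.


def encode_params(params):
    """Replace '~' in keys before handing the tree to the checkpointer."""
    return rename_keys(params, "~", PLACEHOLDER)


def decode_params(params):
    """Restore the original Haiku-style keys after loading."""
    return rename_keys(params, PLACEHOLDER, "~")


# Stand-in for Haiku params: module name -> parameter name -> value.
params = {
    "mlp/~/linear_0": {"w": [[1.0, 2.0]], "b": [0.0]},
    "mlp/~/linear_1": {"w": [[3.0, 4.0]], "b": [1.0]},
}

safe_params = encode_params(params)
print(sorted(safe_params))
print(decode_params(safe_params) == params)
```

The encoded tree would be passed to `save` in place of `params`, and `decode_params` applied to whatever is restored.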

chriscarmona avatar Oct 26 '23 17:10 chriscarmona