
New interface does not support custom empty pytree class inherited from dict

Open · ZaberKo opened this issue 1 year ago · 1 comment

Reproduction code:

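```python
import jax
import orbax.checkpoint as ocp


class PyTreeDict(dict):
    pass


jax.tree_util.register_pytree_node(
    PyTreeDict,
    # flatten: the values are the children, the keys are the auxiliary data
    lambda d: (tuple(d.values()), tuple(d.keys())),
    # unflatten: receives (aux_data, children) and rebuilds the mapping
    lambda keys, values: PyTreeDict(dict(zip(keys, values)))
)

# Assumption: the original snippet does not show how `ckpt` was created;
# a StandardCheckpointer is used here so the snippet runs.
ckpt = ocp.StandardCheckpointer()

a = {"a": PyTreeDict()}  # ValueError: Expected dict, got {}.
# a = PyTreeDict()  # ValueError: Found empty item

path = ocp.test_utils.erase_and_create_empty('./debug').resolve() / 'ckpt'
ckpt.save(path, a)
ckpt.restore(path, args=ocp.args.StandardRestore(a))
```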
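One way to see why this object is an "empty node" is to inspect it with `jax.tree_util` alone, without Orbax. The sketch below assumes a recent JAX that provides `tree_flatten_with_path` and `keystr`; `is_empty_container` and `find_empty_nodes` are hypothetical helpers written for illustration, not Orbax or JAX APIs. It shows that an empty `PyTreeDict` contributes no leaves and has a different tree structure from a plain `{}`. That is consistent with the errors above, but it does not reproduce Orbax's internal checks.

```python
import jax


# Same custom container as in the reproduction: a dict subclass
# registered as its own pytree node type (JAX looks up exact types,
# so it is NOT treated as a plain dict).
class PyTreeDict(dict):
    pass


jax.tree_util.register_pytree_node(
    PyTreeDict,
    lambda d: (tuple(d.values()), tuple(d.keys())),
    lambda keys, values: PyTreeDict(dict(zip(keys, values))),
)


def is_empty_container(x):
    """Hypothetical helper: True for dict/list/tuple nodes with no children."""
    return isinstance(x, (dict, list, tuple)) and len(x) == 0


def find_empty_nodes(tree):
    """Hypothetical helper: key paths of empty container nodes in `tree`."""
    # Treat empty containers as leaves so they show up with their paths.
    paths_and_leaves, _ = jax.tree_util.tree_flatten_with_path(
        tree, is_leaf=is_empty_container
    )
    return [
        jax.tree_util.keystr(path)
        for path, leaf in paths_and_leaves
        if is_empty_container(leaf)
    ]


a = {"a": PyTreeDict()}

# The empty PyTreeDict contributes no leaves at all.
print(jax.tree_util.tree_leaves(a))  # []
# Its node type is PyTreeDict, not dict, so the structure differs
# from the "equivalent" plain-dict tree.
print(jax.tree_util.tree_structure(a) == jax.tree_util.tree_structure({"a": {}}))  # False
print(find_empty_nodes(a))  # ["['a']"]
```

Running a helper like this over a state before saving is one hedged way to locate the subtrees that hit the empty-node path.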

This issue is related to #720 and https://github.com/google/orbax/commit/a066d9c83047185c15d29d88eff989d3101b8136. @niketkumar

Posted by ZaberKo, Apr 13 '24 19:04

Thanks for reporting; we're looking into some refactoring that will resolve these empty-node issues.

Posted by cpgaffney1, Apr 19 '24 15:04