
Incorrect null check in pytree_checkpoint_handler.py

Status: Open. hr0nix opened this issue 2 years ago • 4 comments

orbax/checkpoint/pytree_checkpoint_handler.py:661 has the following check: if not item

It should most likely be if item is None. Otherwise the check raises an error when item is an array with more than one element, because the truth value of such an array is ambiguous, and an array is a valid pytree according to the pytree definition.

hr0nix, Aug 27 '23 14:08
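A minimal sketch of the reported problem, assuming NumPy is installed. is_missing_buggy and is_missing_fixed are hypothetical stand-ins for the two versions of the check discussed above, not Orbax code. `not item` calls bool() on the item, which fails for multi-element arrays and silently misclassifies a zero-valued scalar array, while `item is None` never inspects the array's contents.

```python
import numpy as np


def is_missing_buggy(item):
    # Mirrors the reported check `if not item`: this calls bool() on the item.
    return not item


def is_missing_fixed(item):
    # Proposed check: only an explicit None means "no item was provided".
    return item is None


# A multi-element array has no single truth value, so `not` raises ValueError.
try:
    is_missing_buggy(np.arange(3))
except ValueError as err:
    print("buggy check raised:", err)

# A zero-valued 0-d array does not raise, but is wrongly treated as missing.
print(is_missing_buggy(np.array(0)))  # True
print(is_missing_fixed(np.array(0)))  # False
print(is_missing_fixed(None))         # True
```

The two checks also differ on empty containers: not {} is True, whereas {} is None is False.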

Please use ArrayCheckpointHandler instead of PyTreeCheckpointHandler to save a single array.

cpgaffney1, Aug 28 '23 15:08

I think the distinction being made is between PyTree containers and leaves. From the JAX docs: "By default, pytree containers can be lists, tuples, dicts, namedtuple, None, OrderedDict. Other types of values, including numeric and ndarray values, are treated as leaves." To save a single jax.Array or np.ndarray, use ArrayCheckpointHandler instead. Obviously there is some grey area, but I'm hesitant to cram yet more functionality into PyTreeCheckpointHandler: it has become bloated enough as it is, and we're pushing for simplification in a few key areas.

cpgaffney1, Aug 28 '23 15:08
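The container/leaf distinction described above can be sketched as a small routing rule, assuming NumPy. choose_handler_name is a hypothetical helper, not part of Orbax, and it returns handler names as strings rather than handler objects. It only recognizes the default container types from the quoted JAX docs, so custom pytree nodes registered through jax.tree_util are not handled.

```python
from collections import OrderedDict, namedtuple

import numpy as np

# Default pytree container types from the JAX docs quoted above. namedtuple and
# OrderedDict are covered because they subclass tuple and dict. None is also a
# default container, but it holds no leaves, so this sketch leaves it unrouted.
_CONTAINER_TYPES = (list, tuple, dict)


def choose_handler_name(item):
    """Hypothetical helper: name the Orbax handler suited to `item`, or None."""
    if isinstance(item, _CONTAINER_TYPES):
        return "PyTreeCheckpointHandler"
    # Duck-typed single-array check: np.ndarray and jax.Array expose shape and dtype.
    if hasattr(item, "shape") and hasattr(item, "dtype"):
        return "ArrayCheckpointHandler"
    # Anything else (None, plain Python scalars, custom pytree nodes) is unclassified.
    return None


Point = namedtuple("Point", ["x", "y"])
print(choose_handler_name({"w": np.ones(2), "b": np.zeros(2)}))  # PyTreeCheckpointHandler
print(choose_handler_name(Point(np.zeros(1), np.ones(1))))      # PyTreeCheckpointHandler
print(choose_handler_name(OrderedDict(a=np.zeros(1))))          # PyTreeCheckpointHandler
print(choose_handler_name(np.arange(3)))                        # ArrayCheckpointHandler
print(choose_handler_name(None))                                # None
```

Routing on the top-level type keeps PyTreeCheckpointHandler limited to containers, in line with the preference above for not growing that handler further.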

Fair enough. In any case, the error was really obscure, and the error message for this case could be improved.

hr0nix, Aug 28 '23 15:08
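One way a clearer error could look is sketched below, assuming NumPy. validate_pytree_item is a hypothetical pre-save check, and both error messages are illustrative, not Orbax's actual wording. It replaces the truthiness test with an explicit None check and, when a bare array is passed, names the handler to use instead.

```python
import numpy as np


def validate_pytree_item(item):
    """Hypothetical pre-save validation; the messages are illustrative only."""
    # Explicit None check instead of `if not item`, so arrays never reach bool().
    if item is None:
        raise ValueError("No item was provided to save (got None).")
    # Duck-typed single-array check: np.ndarray and jax.Array expose shape and dtype.
    if hasattr(item, "shape") and hasattr(item, "dtype"):
        raise TypeError(
            f"Got a single array with shape {tuple(item.shape)}, but "
            "PyTreeCheckpointHandler expects a pytree container such as a dict, "
            "list or tuple. Use ArrayCheckpointHandler to save a single array."
        )
    return item


try:
    validate_pytree_item(np.zeros((2, 3)))
except TypeError as err:
    print(err)
```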

Sure, that's also fair. I'll make a TODO.

cpgaffney1, Aug 28 '23 21:08