
`ocp.tree.serialize_tree` filtering logic for sequences with empty leaves

Open · JesseFarebro opened this issue 1 year ago · 1 comment

Hi,

I came across `ocp.tree.serialize_tree`, but its serialization logic appears to fail when a sequence contains empty leaves. This comes up often with optax, where you end up with `optax.EmptyState()` inside a tuple. Here's a minimal reproduction:

```python
import optax
import orbax.checkpoint as ocp

tree = (0, optax.EmptyState(), 1)  # or None, etc.
ocp.tree.serialize_tree(tree)
```

resulting in:

```text
  File ".../orbax/checkpoint/tree/utils.py", line 79, in _extend_list
    assert idx <= len(ls)
           ^^^^^^^^^^^^^^
AssertionError
```
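A plausible explanation for the assertion, offered as a sketch rather than a description of Orbax's actual code: `None` and a field-less NamedTuple such as `optax.EmptyState()` are pytree nodes with zero children, so flattening `(0, optax.EmptyState(), 1)` emits leaves only at indices 0 and 2. A rebuilder that appends items one by one and asserts `idx <= len(ls)` then receives index 2 while the list has length 1. The helpers `flatten_with_indices`, `rebuild_strict` and `rebuild_padded` below are hypothetical, `EmptyState` is a local stand-in for the optax class, and the rebuilders only handle depth-1 paths.

```python
from typing import Any, Iterator, NamedTuple


class EmptyState(NamedTuple):
    """Stand-in for optax.EmptyState: a NamedTuple with no fields."""


def flatten_with_indices(tree: Any, path: tuple = ()) -> Iterator[tuple]:
    """Yield (path, leaf) pairs; None and empty tuples contribute no leaves.

    Assumption: this mimics how JAX pytrees treat None and field-less
    NamedTuples (nodes with zero children), so no path is emitted for them.
    """
    if tree is None:
        return
    if isinstance(tree, (tuple, list)):
        for i, child in enumerate(tree):
            yield from flatten_with_indices(child, path + (i,))
        return
    yield path, tree


def rebuild_strict(pairs: list) -> list:
    """Hypothetical rebuilder that assumes indices arrive without gaps."""
    out: list = []
    for (idx,), value in pairs:  # depth-1 paths only
        assert idx <= len(out)  # fails when an empty leaf left a gap
        if idx == len(out):
            out.append(value)
        else:
            out[idx] = value
    return out


def rebuild_padded(pairs: list, fill: Any = None) -> list:
    """Hypothetical rebuilder that pads gaps left by empty leaves with `fill`."""
    out: list = []
    for (idx,), value in pairs:  # depth-1 paths only
        if idx >= len(out):
            out.extend([fill] * (idx + 1 - len(out)))
        out[idx] = value
    return out


pairs = list(flatten_with_indices((0, EmptyState(), 1)))
print(pairs)  # [((0,), 0), ((2,), 1)]
print(rebuild_padded(pairs))  # [0, None, 1]
```

Padding avoids the assertion but does not restore the original structure: the gap comes back as `fill`, not as `EmptyState()`, so a faithful round trip would also need the original tree structure.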

I'm not sure what the ideal solution is here; I don't have enough context on the intended purpose of `serialize_tree` and `deserialize_tree`.

JesseFarebro · Nov 19 '24

This function is typically used internally. Could you clarify what you're trying to achieve?

ChromeHearts · Nov 20 '24