Can't checkpoint Numpy RNG state

Open rainx0r opened this issue 1 year ago • 6 comments

Hi. So I'm trying to use orbax-checkpoint to checkpoint my full experiment state, not just my network weights, and part of this state is NumPy RNG states, which look as follows:

>>> import numpy as np
>>> rng = np.random.default_rng(42)
>>> state = rng.__getstate__()
>>> state
{'bit_generator': 'PCG64', 'state': {'state': 274674114334540486603088602300644985544, 'inc': 332724090758049132448979897138935081983}, 'has_uint32': 0, 'uinteger': 0}

When trying to checkpoint this with orbax-checkpoint, I get the following error:

import orbax.checkpoint as ocp

dir = ocp.test_utils.erase_and_create_empty("/tmp/string-checkpoint-reprod")

ckpt_manager = ocp.CheckpointManager(
    dir,
    options=ocp.CheckpointManagerOptions(
        max_to_keep=5,
        create=True,
    ),
)

ckpt_manager.save(0, args=ocp.args.NumpyRandomKeySave(state))
ValueError: Error parsing object member "dtype": Unsupported data type: "object" [source locations='tensorstore/internal/json_binding/json_binding.h:384\ntensorstore/internal/json_binding/json_binding.h:524\ntensorstore/internal/json_binding/json_binding.h:861\ntensorstore/internal/json_binding/json_binding.h:825']

~~I looked through the API and the docs and it does seem like string leaf nodes should be supported. But I think the problem is that for some reason orbax-checkpoint doesn't even see the string as str (which is what StringHandler is registered for), it sees it as an object (which is technically not false).~~

Edit: The issue seems to be that the large integers in the NumPy random key state get turned into dtype=object NumPy arrays, which is not Orbax's fault. But perhaps Orbax should be able to handle this in some way?
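The dtype=object behaviour can be reproduced with NumPy alone. The following is a minimal sketch that uses the `state` value from the output above. It assumes that a plain `np.asarray` call is representative of how the scalar gets converted; whether Orbax performs exactly this call is an assumption, not something confirmed in this thread.

```python
import numpy as np

# The 'state' value from the PCG64 state dict shown above.
big = 274674114334540486603088602300644985544

# The value exceeds the largest fixed-width integer NumPy offers (uint64).
uint64_max = np.iinfo(np.uint64).max
too_big_for_uint64 = big > uint64_max

# With no integer dtype wide enough, NumPy falls back to dtype=object.
arr = np.asarray(big)
print(too_big_for_uint64, arr.dtype)
```

An object array holds a reference to the Python int rather than raw bytes, which is consistent with TensorStore rejecting the "object" data type in the error above.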

I noticed in the changelog entry that introduces NumpyRandomKeyCheckpointHandler that it's intended for numpy.random.get_state(). That is just the global RNG state, though, and more or less a legacy feature. As far as I know, current NumPy RNG best practice is to instantiate and use individual Generator objects through the new API. These seem to have a different kind of state, and that's where the issue is.

Here is a Colab notebook reproducing the issue on orbax-checkpoint==0.5.23.

rainx0r avatar Aug 16 '24 10:08 rainx0r

I stepped through the Orbax source with a debugger and realised that the problem isn't the string; it gets handled by StringHandler correctly. The problem is the really large integers in the NumPy RNG state. It looks like when they're handled by ScalarHandler, they get turned into NumPy arrays of dtype object.

As a result, I'm heavily reframing my issue: it is not about an inability to checkpoint string leaves, but about checkpointing NumPy random key states properly. Apologies for any confusion.

I have looked through the source of NumpyRandomKeySave and NumpyRandomKeyCheckpointHandler, but they seem to defer to PyTreeSave / PyTreeCheckpointHandler anyway, which currently fail on this state.

rainx0r avatar Aug 16 '24 13:08 rainx0r

How are you planning to construct the Generator object back from the restored state?

niketkumar avatar Aug 27 '24 02:08 niketkumar

The Generator objects have a __setstate__() function that takes in the dict state exported from the Generator's __getstate__() function and that's how I currently have it implemented. Not entirely sure if these two functions are part of the Generator's public API or intended to be used directly, but another strategy that doesn't use any potentially "private" functions is to get the state through the Generator's .bit_generator.state attribute and then to set it by assigning it back.
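The second strategy, reading and assigning `.bit_generator.state`, can be sketched as follows. This is a minimal illustration rather than anything Orbax-specific; the variable names are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(42)

# Reading the property returns a fresh dict describing the current state.
saved_state = rng.bit_generator.state

# Draw a few numbers after the snapshot; these are what a restore must replay.
expected = rng.random(3)

# Any seed works here, because the state is overwritten on the next line.
# The new generator must use the same BitGenerator type (PCG64 by default).
restored_rng = np.random.default_rng()
restored_rng.bit_generator.state = saved_state

actual = restored_rng.random(3)
print(np.array_equal(expected, actual))
```

The restored generator replays the same draws as the original did after the snapshot, which is the property a checkpoint needs.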

For now, I have resorted to using JsonSave() and JsonRestore() for these RNG states.

rainx0r avatar Aug 27 '24 10:08 rainx0r

Thanks for sharing the details.

Will using numpy.random.get_state(legacy=False) meet your requirements? In that case, Orbax already supports it. Please take a look at this unit test: https://github.com/google/orbax/blob/53e2f22234717d29eca59282b496d3a6ba897b84/checkpoint/orbax/checkpoint/random_key_checkpoint_handler_test.py#L118

Alternatively, using JSON should suffice too.

niketkumar avatar Aug 27 '24 21:08 niketkumar

The problem isn't really the format: the format returned by numpy.random.get_state(legacy=False) is the same as the one you get by accessing Generator.bit_generator.state (or calling Generator.__getstate__()). The problem is the BitGenerator. The new Generator objects don't use MT19937, which is what the older global numpy.random state uses, but PCG64, whose state consists of 128-bit integers. Those are too large for any fixed-width NumPy integer dtype, and converting scalars to NumPy arrays is Orbax's default way of serialising them, from what I can tell.

Serialising with JSON works because it just writes it into a string essentially and Python can read it back from said string without issue. I suppose I can stick with that, or put the numpy Generator objects into my checkpoint PyTree directly instead of their state and then write a TypeHandler for numpy.random.Generator that extracts / sets the state and serialises with JSON.
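The JSON round trip described above can be checked with the standard library alone. This sketch reuses the state dict from the first post and only shows that Python's `json` module preserves the large integers; it does not involve Orbax's JsonSave or JsonRestore.

```python
import json

# State dict copied from the first post (PCG64, seed 42).
state = {
    "bit_generator": "PCG64",
    "state": {
        "state": 274674114334540486603088602300644985544,
        "inc": 332724090758049132448979897138935081983,
    },
    "has_uint32": 0,
    "uinteger": 0,
}

# Python ints have arbitrary precision, and json writes them out as plain
# decimal digits, so nothing is truncated on the way to text and back.
serialised = json.dumps(state)
restored = json.loads(serialised)

print(restored == state)
```

The restored dict has the same shape as the original, so it can be assigned back to a Generator's `.bit_generator.state`.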

rainx0r avatar Aug 28 '24 09:08 rainx0r

Thanks for clarifying the difference between MT19937 and PCG64!

A JSON based solution is ideal for this scenario. I will look into it.

niketkumar avatar Aug 28 '24 14:08 niketkumar

@rainx0r Sorry for the late response; this issue was only brought to my attention this week. Since the generator state is JSON-serializable, you can simply use orbax.checkpoint.args.JsonSave to save it. Here are examples of how to use JsonSave:

ChromeHearts avatar Feb 07 '25 22:02 ChromeHearts