possible penzai pz.ParameterValue TypeHandler? Or, how to use orbax with penzai?
Hi,
I'm trying to use orbax to checkpoint arbitrary penzai models. Their parameters are of type pz.ParameterValue, which contain pz.NamedArray instances, and those instances carry named axes. So I thought I'd try my hand at implementing a class derived from type_handlers.TypeHandler. However, I can't see where in this workflow the axis names would be stored. I saw there is a TypeHandler.metadata method, but that seems to be called only during restore. And TypeHandler.serialize doesn't seem to provide an opportunity to specialize except at a very low level, at the tensorstore level.
Am I missing something else? It would be nice to be able to use orbax with penzai models.
On the penzai side, they claim that orbax can be used, but there is no example of saving / loading an arbitrary model.
This is the first time I've heard about penzai, so my advice may be somewhat limited. I do see some examples of loading Orbax checkpoints, but nothing about saving: https://penzai.readthedocs.io/en/stable/notebooks/induction_heads.html#loading-gemma.
It seems like the easiest thing to do might be to extract the base jax.Array representation from the pz.NamedArray. Named axes could potentially be stored as a separate item, represented as a nested tree where the values are tuples of strings. JsonCheckpointHandler would be appropriate for saving this.
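Roughly what I have in mind, as a library-free sketch rather than tested penzai/Orbax code: NamedArrayStandIn, split_tree and merge_tree are hypothetical names I made up for illustration, and plain Python lists stand in for jax.Array. The idea is to split the parameter tree into a tree of raw arrays (which you would save with Orbax as usual) and a parallel tree of axis names (which is plain JSON).

```python
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class NamedArrayStandIn:
    """Hypothetical stand-in for a named array: raw data plus axis names."""
    data: Any                     # would be a jax.Array in practice
    axis_names: tuple[str, ...]   # one name per dimension of `data`


def split_tree(tree):
    """Split a nested dict of named arrays into (raw arrays, axis names)."""
    if isinstance(tree, NamedArrayStandIn):
        # JSON has no tuple type, so store the names as a list.
        return tree.data, list(tree.axis_names)
    arrays, names = {}, {}
    for key, subtree in tree.items():
        arrays[key], names[key] = split_tree(subtree)
    return arrays, names


def merge_tree(arrays, names):
    """Inverse of split_tree: reattach axis names to the raw arrays."""
    if isinstance(names, list):   # a leaf in the names tree
        return NamedArrayStandIn(arrays, tuple(names))
    return {key: merge_tree(arrays[key], names[key]) for key in names}


params = {
    "embed": {
        "table": NamedArrayStandIn([[0.1, 0.2], [0.3, 0.4]], ("vocab", "embedding")),
    },
    "bias": NamedArrayStandIn([0.0, 1.0], ("embedding",)),
}

arrays, names = split_tree(params)
names_json = json.dumps(names)    # this string is what the JSON item would hold
restored = merge_tree(arrays, json.loads(names_json))
```

On restore you would load both items and call merge_tree. With real penzai objects, the leaf check and the way data and names are read off the array would need to be adapted to the NamedArray API.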
If you want to implement a TypeHandler though, I'd suggest implementing one that delegates to an underlying ArrayHandler for saving the jax.Array. Then, serialize should write some additional file(s) (e.g. JSON) that represent named_axes and any other relevant properties. ArrayHandler itself is a good model for this: it stores some sharding metadata using the TensorStore JSON driver.
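To illustrate the delegation pattern only, here is a minimal sketch that deliberately does not use the real TypeHandler interface (which is async and operates on batches of values). ListArrayHandler, NamedArraySidecarHandler, their save/load methods and the file names are all hypothetical; the inner handler plays the role ArrayHandler would play, and the wrapper adds a JSON sidecar file holding the axis names.

```python
import json
import tempfile
from pathlib import Path


class ListArrayHandler:
    """Hypothetical stand-in for the delegate (ArrayHandler in Orbax)."""

    def save(self, directory: Path, name: str, data) -> None:
        (directory / f"{name}.data.json").write_text(json.dumps(data))

    def load(self, directory: Path, name: str):
        return json.loads((directory / f"{name}.data.json").read_text())


class NamedArraySidecarHandler:
    """Delegates the raw array to `inner` and writes axis names alongside it."""

    def __init__(self, inner):
        self._inner = inner

    def save(self, directory: Path, name: str, data, axis_names) -> None:
        # 1. Let the delegate write the array data.
        self._inner.save(directory, name, data)
        # 2. Write the extra metadata as a separate JSON file.
        sidecar = {"axis_names": list(axis_names)}
        (directory / f"{name}.axes.json").write_text(json.dumps(sidecar))

    def load(self, directory: Path, name: str):
        data = self._inner.load(directory, name)
        sidecar = json.loads((directory / f"{name}.axes.json").read_text())
        return data, tuple(sidecar["axis_names"])


def roundtrip_demo():
    """Save and reload one named array in a temporary directory."""
    with tempfile.TemporaryDirectory() as tmp:
        directory = Path(tmp)
        handler = NamedArraySidecarHandler(ListArrayHandler())
        handler.save(directory, "w", [[1.0, 2.0]], ("batch", "feature"))
        return handler.load(directory, "w")
```

In a real handler, the two steps in save would live inside serialize, and deserialize would read the sidecar back before rebuilding the named array.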