
Pattern for lossless round trip of module serialization without storing side information required by __init__?

Open cottrell opened this issue 9 months ago • 14 comments

Wouldn't it make sense to add something that wraps all of this so that deserialization works on its own? Currently the round trip is lossy: you basically need to either define an init_param_from_params function or store the init params separately somewhere.

I might be missing something included elsewhere.

What I mean is that calling eqx.nn.MLP(2, 2, 2, 2, ...) should not be needed in order to deserialize. In pure JAX, for example, you write everything in terms of params; I simply have JSON serializers ("jsonifiers") for the params and it all works. The "init params" are not needed.
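For comparison, a minimal sketch of the pure-JAX workflow described above, assuming the params are stored as nested dicts of arrays. Because the dict itself records the layer structure and the array shapes record the sizes, the JSON round trip needs nothing besides the params. All function names here are hypothetical.

```python
import json

import jax
import jax.numpy as jnp
import jax.random as jr


def init_mlp_params(key, sizes):
    # Plain-JAX MLP params: {"layer_i": {"W": ..., "b": ...}} for each layer.
    keys = jr.split(key, len(sizes) - 1)
    return {
        f"layer_{i}": {"W": jr.normal(k, (n_out, n_in)), "b": jnp.zeros(n_out)}
        for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:]))
    }


def mlp_apply(params, x):
    # Depth and widths are read off the params; no separate config is needed.
    n = len(params)
    for i in range(n):
        layer = params[f"layer_{i}"]
        x = layer["W"] @ x + layer["b"]
        if i < n - 1:
            x = jax.nn.relu(x)
    return x


def params_to_json(params):
    # Arrays become nested lists; the dict structure is kept as-is.
    return json.dumps(jax.tree_util.tree_map(lambda a: a.tolist(), params))


def params_from_json(text):
    # Treat each nested list as one array leaf rather than as a pytree node.
    return jax.tree_util.tree_map(
        jnp.asarray, json.loads(text), is_leaf=lambda v: isinstance(v, list)
    )


params = init_mlp_params(jr.PRNGKey(0), [2, 4, 2])
restored = params_from_json(params_to_json(params))  # no sizes or depth needed
```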

A special classmethod might make sense here. I'm not deep enough into the internals of Equinox yet, but the philosophy sounds like it should all be separable, with Equinox merely being a nice way of organizing JAX params.

def from_params(...):
   ...


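For reference, this is the current round trip from https://docs.kidger.site/equinox/api/serialisation/, where tree_deserialise_leaves needs model_original as a template: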

import equinox as eqx
import jax.random as jr

model_original = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
eqx.tree_serialise_leaves("some_filename.eqx", model_original)
model_loaded = eqx.tree_deserialise_leaves("some_filename.eqx", model_original)

# To partially load weights: in this case load everything except the final layer,
# which is taken from model_original instead.
model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], model_loaded, model_original.layers[-1])

cottrell, Sep 29 '23 15:09