equinox
Pattern for lossless round trip of module serialization without storing side information required by __init__?
Wouldn't it make sense to add something that wraps all this so that deserialization works on its own? Currently it's lossy: you basically need to either define an init_param_from_params function or store the init params separately somewhere.
I might be missing something included elsewhere.
What I mean is that constructing eqx.nn.MLP(2, 2, 2, 2, ...) should not be needed to "deserialize". For example, in pure jax you write everything based on params; I simply have jsonifiers of params and it all works. The "init_params" are not needed.
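To make the pure-params workflow described above concrete, here is a minimal sketch. It assumes a hypothetical `forward` function and a hand-written `params` dictionary, and it uses plain Python lists instead of jax arrays so that it runs without any dependencies. The point is only that the model is fully described by its parameters, so a JSON round trip loses nothing.

```python
import json


def forward(params, x):
    """Apply a tiny MLP that is defined only by its parameters (hypothetical helper)."""
    n_layers = len(params["layers"])
    for i, layer in enumerate(params["layers"]):
        # Dense layer: y = W x + b, written out with plain lists.
        x = [
            sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(layer["weight"], layer["bias"])
        ]
        if i < n_layers - 1:
            x = [max(0.0, v) for v in x]  # ReLU on hidden layers only
    return x


# Illustrative parameters: a 2 -> 2 -> 1 network.
params = {
    "layers": [
        {"weight": [[1.0, -1.0], [0.5, 0.5]], "bias": [0.0, 0.1]},
        {"weight": [[1.0, 2.0]], "bias": [-0.5]},
    ]
}

# Round trip through JSON: no constructor arguments are needed to restore.
restored = json.loads(json.dumps(params))
same_output = forward(restored, [1.0, 2.0]) == forward(params, [1.0, 2.0])
```

The layer sizes and depth are implied by the shapes of the stored lists, which is why no separate "init params" are required in this style.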
A special classmethod might make sense here. I'm not deep enough into the internals of equinox yet. But the philosophy sounds like it should all be separable, with equinox merely being a nice way of organizing jax params.
def from_params(...):
    ...
https://docs.kidger.site/equinox/api/serialisation/
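import equinox as eqx
import jax.random as jr

model_original = eqx.nn.MLP(2, 2, 2, 2, key=jr.PRNGKey(0))
# Only the leaves (the arrays) are written; the tree structure is not stored.
eqx.tree_serialise_leaves("some_filename.eqx", model_original)
# The second argument is a "like" pytree that supplies the structure to load into.
model_loaded = eqx.tree_deserialise_leaves("some_filename.eqx", model_original)
# To partially load weights: in this case load everything except the final layer.
# Note that the replacement must be the final layer itself, not the whole model.
model_partial = eqx.tree_at(lambda mlp: mlp.layers[-1], model_loaded, model_original.layers[-1])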