How to serialize a model + state pair?
I'm working on a little project that eases converting a PyTorch model to your own JAX model (shameless advertisement). After setting all the weights, biases, and states, I wanted to serialise everything.
In other words, I have this:
import equinox as eqx
import jax

in_size = 784
out_size = 10
width_size = 64
depth = 2
key = jax.random.PRNGKey(22)

class EqxMLP(eqx.Module):
    mlp: eqx.nn.MLP
    batch_norm: eqx.nn.BatchNorm

    def __init__(self, in_size, out_size, width_size, depth, key):
        self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
        self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")

    def __call__(self, x, state):
        # BatchNorm returns (output, updated_state)
        return self.batch_norm(self.mlp(x), state)

jax_model, state = eqx.nn.make_with_state(EqxMLP)(in_size, out_size, width_size, depth, key)

# all of the custom logic which loaded the torch weights and the batch norm states into my network
new_model, new_state = magic(...)

# how to serialise this?
And now my questions are:
1. Do I serialise both independently?
2. Do I combine them and then serialise the combined pytree?
3. If 2: how do I combine them again?
My understanding is that eqx.nn.make_with_state simply splits the model by whether a leaf is of type StateIndex or not. I suppose I just have to invert that, but I don't know how to :(