How to serialize a model + state pair?

Open Artur-Galstyan opened this issue 2 months ago • 2 comments

I'm working on a little project that eases converting a PyTorch model to your own JAX model. After setting all the weights, biases, and states, I wanted to serialise everything.

In other words, I have this:

    import equinox as eqx
    import jax

    in_size = 784
    out_size = 10
    width_size = 64
    depth = 2
    key = jax.random.PRNGKey(22)

    class EqxMLP(eqx.Module):
        mlp: eqx.nn.MLP
        batch_norm: eqx.nn.BatchNorm

        def __init__(self, in_size, out_size, width_size, depth, key):
            self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
            self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")

        def __call__(self, x, state):
            return self.batch_norm(self.mlp(x), state)
 
    jax_model, state = eqx.nn.make_with_state(EqxMLP)(in_size, out_size, width_size, depth, key)
    # all of the custom logic which loaded the torch weights and the batch norm states into my network
    new_model, new_state = magic(...)
    # how to serialise this?

And now my questions are:

  1. Do I serialise both independently?
  2. Do I combine them and then serialise the combined pytree?
  3. If (2): how do I combine them again?

My understanding is that eqx.nn.make_with_state simply splits the model by whether the leaf is of type StateIndex or not. I suppose I just have to invert that, but I don't know how to :(

Artur-Galstyan, May 06 '24 20:05