
Question: States, `vmap`-construction-of-layers and `lax.scan`'ing.

Open: homerjed opened this issue 1 year ago • 7 comments

Hello!

I'm having some trouble with states and the construction of a model that filter_vmap-constructs layers.

I've checked your FAQ and the stateful-ops page, but I can't get this model working as I expect...

  • In the normal Model I build a model as you'd expect. I batch the states for application to a batch of data. All good.

  • In OddModel I try to do the same thing, but I vmap-construct my layers and lax.scan over them in __call__. This causes an error, because the state returned by the initial eqx.nn.make_with_state call is already batched.

This error is expected, but how do I make OddModel work as intended? I've tried using a substate, but I'm confused, because I don't actually need to vmap the layers (which, as I understand it, is what a substate is for).

What am I missing? Thanks as always!

(This is an MWE of a very large transformer-based model that uses k-v caches parameterised with eqx.nn.States)

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx 

class WeirdLayer(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    index: eqx.nn.StateIndex

    def __init__(self, key):
        self.linear1 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.linear2 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.index = eqx.nn.StateIndex(jnp.zeros((4,)))

    def __call__(self, x, state):
        y = state.get(self.index)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = x + y
        print("x", x.shape) # This becomes 3, 4 when this layer is used in OddModel, and therefore breaks!
        new_state = state.set(self.index, x)
        x = self.linear2(x)
        return x, new_state

class Model(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model builds its layers the normal way: a Python list of
            layers, iterated over with a for loop in __call__
        """
        self.layers = [WeirdLayer(key) for _ in range(3)]
    
    def __call__(self, x, state):
        for l in self.layers:
            x, state = l(x, state)
        return x, state

key = jax.random.key(42)
# Use the normal model, see what happens
x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(Model)(key)
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()
y, state = jax.vmap(m)(x, states)
print("Out:", y.shape) # All is well

class OddModel(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model uses vmap'd construction of layers 
            and a lax.scan over them in __call__
        """
        keys = jr.split(key, 3)
        self.layers = eqx.filter_vmap(lambda key: WeirdLayer(key))(keys)

    def __call__(self, x, state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x_and_state, params):
            x, state = x_and_state
            layer = eqx.combine(params, static)
            substate = state.substate(self.layers)
            x, substate = layer(x, substate)
            state = state.update(substate)
            return (x, state), None

        (x, state), _ = jax.lax.scan(_step, (x, state), all_params)
        return x, state

x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)

jax.debug.print("{}", state) # Batched, as expected

states = eqx.filter_vmap(lambda: state, axis_size=len(x))()

y = jax.vmap(m)(x, states) # Error!
print(y)

homerjed, Jan 16 '25 15:01

Hi Jed,

welcome to Equinox :) The vmap operation after creating OddModel is redundant, since your model is already batched when you create it. That means that the state is batched too.

Your code runs if the last lines are something like this:

...
m, state = eqx.nn.make_with_state(OddModel)(key)
y, state = jax.vmap(m)(x, state) 
print(y, y.shape)

johannahaffner, Jan 16 '25 17:01

Hi Johanna!

Thank you for the reply! I might be missing something here, but if I change the batch size, e.g. to x = jnp.ones((7, 4)), the error returns.

It seems like the state returned by make_with_state is a batched state of a single module: the leading axis of the state from m, state = eqx.nn.make_with_state(OddModel)(key) has 3 elements, one for each WeirdLayer in OddModel. I'd expect that from the vmap'd construction of the layers in OddModel.

This means that when I vmap over the states, the batch axis of the data x and the leading (layer) axis of the states are not aligned. I need this vmap for my application (it would also be needed, e.g., for eqx.nn.BatchNorm during training).
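The misalignment can be reproduced with plain arrays. In this sketch (plain JAX only; the shapes mirror the thread: 7 examples, 3 layers, 4 features, and `per_example` is a hypothetical stand-in for the model call), jax.vmap refuses to map a batch of 7 against a state whose leading axis is the layer axis of size 3. Giving the state its own leading batch axis, which is what eqx.filter_vmap(lambda: state, axis_size=len(x))() does in the question, restores the alignment while each example still sees the full (3, 4) per-layer state:

```python
import jax
import jax.numpy as jnp

n_batch, n_layers, n_features = 7, 3, 4
x = jnp.ones((n_batch, n_features))
layer_state = jnp.zeros((n_layers, n_features))  # leading axis = layers, not batch

seen_shapes = []  # state shape seen by each example, recorded at trace time


def per_example(x_i, state_i):
    seen_shapes.append(state_i.shape)
    return x_i + state_i.sum(axis=0)


# Mapping axis 0 of both arguments fails: the sizes 7 and 3 disagree.
try:
    jax.vmap(per_example)(x, layer_state)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

# Put an explicit batch axis in front of the layer axis: (7, 3, 4).
batched_state = jnp.broadcast_to(layer_state, (n_batch,) + layer_state.shape)
y = jax.vmap(per_example)(x, batched_state)
print(mismatch_raised, y.shape, seen_shapes[-1])  # True (7, 4) (3, 4)
```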

Cheers!

homerjed, Jan 16 '25 17:01

A few points based on my understanding of states. First, the substate doesn't do anything here, since you take the substate of self.layers, which already contains the whole state. Second, and this is the key issue, you are scanning over the layers but not over the states: you pass the full (3, 4) state into every step, when you really want to pass in a (4,) slice each time. The docs example vmaps over the states, whereas here you are scanning over them, so it's slightly different. I wrote something that approaches it from that direction, which seemed to work (I dislike the manual state interaction, it seems brutally inelegant, so I will update it if I think of something more clever).

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx 

class WeirdLayer(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    index: eqx.nn.StateIndex

    def __init__(self, key):
        self.linear1 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.linear2 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.index = eqx.nn.StateIndex(jnp.zeros((4,)))

    def __call__(self, x, state):
        y = state.get(self.index)
        print("y", y.shape)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = x + y
        print("x", x.shape) # This becomes 3, 4 when this layer is used in OddModel, and therefore breaks!
        new_state = state.set(self.index, x)
        x = self.linear2(x)
        return x, new_state

class Model(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model builds its layers the normal way: a Python list of
            layers, iterated over with a for loop in __call__
        """
        self.layers = [WeirdLayer(key) for _ in range(3)]
    
    def __call__(self, x, state):
        for l in self.layers:
            x, state = l(x, state)
        return x, state

key = jax.random.key(42)
# Use the normal model, see what happens
x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(Model)(key)
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()
y, state = jax.vmap(m)(x, states)
print("Out:", y.shape) # All is well

class OddModel(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """ 
            This model uses vmap'd construction of layers 
            and a lax.scan over them in __call__
        """
        keys = jr.split(key, 3)
        self.layers = eqx.filter_vmap(lambda key: WeirdLayer(key))(keys)

    def __call__(self, x, state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x_and_state, params):
            x, state, layer_ind = x_and_state
            layer = eqx.combine(params, static)
            substate = jax.tree.map(lambda x: x[layer_ind], state)
            x, substate = layer(x, substate)
            substate = jax.tree.map(lambda x, y: x.at[layer_ind].set(y), state, substate)
            state = state.update(substate)
            return (x, state, layer_ind + 1), None

        (x, state, _), _ = jax.lax.scan(_step, (x, state, 0), all_params)
        return x, state

x = jnp.ones((7, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)

print(state)

y = jax.vmap(m, in_axes=(0, None))(x, state) # No Error!
print(y)

lockwo, Jan 16 '25 18:01

@lockwo - thank you for this insight!

This was what I was trying just now (I hadn't worked out the tree mapping; I'd been confused by tracers and whatnot...).

If it's useful, I tried partitioning the state in the same way as the layers (e.g. state = eqx.combine(state_params, static_state)), but lax.scan doesn't like iterating over (all_params, all_states) instead of all_params.

I guess that's quite obvious, but you can't use xs=jnp.arange(n_layers) to index the state/params due to the tracers involved (I'm 99% sure I tried exactly that; I've got a long notebook of trials...).
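For comparison, plain JAX does accept a traced integer as an index inside lax.scan (indexing with a tracer lowers to a dynamic slice), so xs=jnp.arange(n_layers) can select per-layer rows of a stacked array. A minimal sketch, where `stacked` is a hypothetical array standing in for stacked per-layer state:

```python
import jax
import jax.numpy as jnp

n_layers, n_features = 3, 4
# Hypothetical stand-in for stacked per-layer state: row i belongs to layer i.
stacked = jnp.arange(n_layers * n_features, dtype=jnp.float32).reshape(n_layers, n_features)


def step(carry, i):
    row = stacked[i]  # `i` is a tracer here; indexing lowers to a dynamic slice
    return carry + row, row


total, rows = jax.lax.scan(step, jnp.zeros(n_features), jnp.arange(n_layers))
print(total)  # [12. 15. 18. 21.]
```

Passing the stacked arrays directly as xs, as suggested later in the thread, avoids the explicit index altogether.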

Thank you both! I will also let you know if I get this fixed :)

homerjed, Jan 16 '25 21:01

> lax.scan doesn't like iterating over (all_params, all_states) instead of all_params.

This surprises me. I think this is the correct solution! You have a different state for each iteration of the scan.

I think @lockwo's example is nearly correct, and you just need to modify his final __call__ to look like this:

def __call__(self, x, all_state):
    all_params, static = eqx.partition(self.layers, eqx.is_array)

    def _step(x, params__state):
        params, state = params__state
        layer = eqx.combine(params, static)
        x, state = layer(x, state)
        return x, state

    x, all_state = jax.lax.scan(_step, x, (all_params, all_state))
    return x, all_state

On which note, heads-up @lockwo that if you were to take a @jax.grad of the scan you have, then you'd run into the same XLA bug I mentioned here ('And sadly, XLA has a longstanding bug in which grad-of-loop-of-inplace will make copies of that buffer during the backward pass!') Specifically you shouldn't do grad-of-loop-of-inplace if you read from the buffer you're also writing to. (But if it's never read during the loop, and only written to, then it's fine.)

patrick-kidger, Jan 16 '25 22:01

Ah yes, that's good. I forgot you don't need to call .update, since the substate wasn't meaningful (which makes it more elegant to scan over both).

Thanks for the bug reminder too ;), hopefully by the time I remember all the bugs in XLA I will also be putting bounties on them.

lockwo, Jan 16 '25 22:01

Awesome stuff, I was hoping you'd see this @patrick-kidger :).

Yeah, I tried the xs argument of lax.scan, but I must have had a hidden bug or something in my long, long notebook. It makes sense: the state is just another pytree there.

Thanks everyone, I learnt a lot here!

homerjed, Jan 17 '25 09:01