Question: States, `vmap`-construction-of-layers and `lax.scan`'ing.
Hello!
I'm having some trouble with states and the construction of a model that `filter_vmap`-constructs its layers.
I've checked your FAQs and the stateful-operations page, but I can't get this model working as I expect...
- In the normal `Model` I build a model as you'd expect. I batch the states for application to a batch of data. All good.
- In the `OddModel` I try to do the same thing, but `vmap`-construct my layers and `lax.scan` over them during my `__call__`. This causes an error, because the state is batched when it is returned by the initial `eqx.nn.make_with_state`.
This error is expected, but how do I make my `OddModel` work as expected? I've tried sub-stating, but I'm confused, because I don't actually need to vmap the layers (which, as I understand it, is what a substate is for).
What am I missing? Thanks as always!
(This is an MWE of a very large transformer-based model that uses k-v caches parameterised with `eqx.nn.State`s.)
```python
import jax
import jax.numpy as jnp
import jax.random as jr

import equinox as eqx


class WeirdLayer(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    index: eqx.nn.StateIndex

    def __init__(self, key):
        self.linear1 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.linear2 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.index = eqx.nn.StateIndex(jnp.zeros((4,)))

    def __call__(self, x, state):
        y = state.get(self.index)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = x + y
        print("x", x.shape)  # This becomes (3, 4) when this layer is used in OddModel, and therefore breaks!
        new_state = state.set(self.index, x)
        x = self.linear2(x)
        return x, new_state


class Model(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """
        This model does the normal building of layers:
        a list of layers and a for loop in __call__.
        """
        self.layers = [WeirdLayer(key) for _ in range(3)]

    def __call__(self, x, state):
        for l in self.layers:
            x, state = l(x, state)
        return x, state


key = jr.key(0)  # added here; the key was undefined in the original snippet

# Use the normal model, see what happens
x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(Model)(key)
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()  # batch the state: one copy per example
y, state = jax.vmap(m)(x, states)
print("Out:", y.shape)  # All is well


class OddModel(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """
        This model uses vmap'd construction of layers
        and a lax.scan over them in __call__.
        """
        keys = jr.split(key, 3)
        self.layers = eqx.filter_vmap(lambda key: WeirdLayer(key))(keys)

    def __call__(self, x, state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x_and_state, params):
            x, state = x_and_state
            layer = eqx.combine(params, static)
            substate = state.substate(self.layers)
            x, substate = layer(x, substate)
            state = state.update(substate)
            return (x, state), None

        (x, state), _ = jax.lax.scan(_step, (x, state), all_params)
        return x, state


x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)
jax.debug.print("{}", state)  # Batched, as expected
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()
y = jax.vmap(m)(x, states)  # Error!
print(y)
```
Hi Jed,
welcome to Equinox :)
The `vmap` after creating `OddModel` is redundant: your model is already batched when you create it, which means that the state is batched too.
Your code runs if the last lines are something like this:
```python
...
m, state = eqx.nn.make_with_state(OddModel)(key)
y, state = jax.vmap(m)(x, state)
print(y, y.shape)
```
Hi Johanna!
Thank you for the reply! I might be missing something here... if I change the batch size, e.g. `x = jnp.ones((7, 4))`, the error returns.
It seems like the state returned by `make_with_state` is a batched state of a single module: the batch axis of the state from `m, state = eqx.nn.make_with_state(OddModel)(key)` has 3 elements, one for each `WeirdLayer` in `OddModel`. I'd expect that, given the `vmap`'d construction of the layers in `OddModel`.
This means that when I vmap over the states, the batch axes of the data `x` and of the states are not aligned. I do this vmap because my application requires it (it is also what you'd do for, e.g., `eqx.nn.BatchNorm` during training).
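For illustration, a small sketch (assuming the script above) that just prints the state's array leaves:

```python
# Every array leaf of the state has a leading axis of 3 (one slot per
# WeirdLayer), independent of the batch size of the data.
m, state = eqx.nn.make_with_state(OddModel)(key)
for leaf in jax.tree.leaves(state):
    print(leaf.shape)  # (3, 4)
```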
Cheers!
A few points, based on my understanding of states. First, the substate doesn't do anything here, since you take the substate of `self.layers`, which already contains the whole state. Second, the key issue is that you scan over the layers but not over their states: you have a `(3, 4)` state and pass the whole thing in at each step, when you really want to pass in a `(4,)` state each time. The docs example vmaps over the layers, whereas here you scan over them, so it's slightly different. I wrote something that approaches it from that direction, which seems to work. (I dislike the manual state interaction, it seems brutally inelegant, so I will update it if I think of something more clever.)
```python
import jax
import jax.numpy as jnp
import jax.random as jr

import equinox as eqx


class WeirdLayer(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    index: eqx.nn.StateIndex

    def __init__(self, key):
        self.linear1 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.linear2 = eqx.nn.Linear(in_features=4, out_features=4, key=key)
        self.index = eqx.nn.StateIndex(jnp.zeros((4,)))

    def __call__(self, x, state):
        y = state.get(self.index)
        print("y", y.shape)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = x + y
        print("x", x.shape)
        new_state = state.set(self.index, x)
        x = self.linear2(x)
        return x, new_state


class Model(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """
        This model does the normal building of layers:
        a list of layers and a for loop in __call__.
        """
        self.layers = [WeirdLayer(key) for _ in range(3)]

    def __call__(self, x, state):
        for l in self.layers:
            x, state = l(x, state)
        return x, state


key = jax.random.key(42)

# Use the normal model, see what happens
x = jnp.ones((3, 4))
m, state = eqx.nn.make_with_state(Model)(key)
states = eqx.filter_vmap(lambda: state, axis_size=len(x))()
y, state = jax.vmap(m)(x, states)
print("Out:", y.shape)  # All is well


class OddModel(eqx.Module):
    layers: list[WeirdLayer]

    def __init__(self, key):
        """
        This model uses vmap'd construction of layers
        and a lax.scan over them in __call__.
        """
        keys = jr.split(key, 3)
        self.layers = eqx.filter_vmap(lambda key: WeirdLayer(key))(keys)

    def __call__(self, x, state):
        all_params, static = eqx.partition(self.layers, eqx.is_array)

        def _step(x_and_state, params):
            x, state, layer_ind = x_and_state
            layer = eqx.combine(params, static)
            # manually slice this layer's (4,) entry out of the (3, 4) state
            substate = jax.tree.map(lambda x: x[layer_ind], state)
            x, substate = layer(x, substate)
            # write the updated slice back into the full state
            substate = jax.tree.map(lambda x, y: x.at[layer_ind].set(y), state, substate)
            state = state.update(substate)
            return (x, state, layer_ind + 1), None

        (x, state, _), _ = jax.lax.scan(_step, (x, state, 0), all_params)
        return x, state


x = jnp.ones((7, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)
print(state)
y = jax.vmap(m, in_axes=(0, None))(x, state)  # No Error!
print(y)
```
@lockwo - thank you for this insight!
This was what I was trying just now (I didn't work out the tree mapping; I've been confused by tracers and whatnot...).
If it's useful, I tried partitioning the state in the same way as the layers (e.g. `state = eqx.combine(state_params, static_state)`), but `lax.scan` doesn't like iterating over `(all_params, all_states)` instead of `all_params`.
I guess that's quite obvious, but you can't use `xs=jnp.arange(n_layers)` to index the state/params, due to the tracers involved (I'm 99% sure I tried exactly that; I've got a long notebook of trials...).
Thank you both! I will also let you know if I get this fixed :)
> `lax.scan` doesn't like iterating over `(all_params, all_states)` instead of `all_params`.
This surprises me; I think this is the correct solution! You have a different state for each iteration of the scan.
I think @lockwo's example is nearly correct, and you just need to modify his final `__call__` to look like this:
```python
def __call__(self, x, all_state):
    all_params, static = eqx.partition(self.layers, eqx.is_array)

    def _step(x, params__state):
        params, state = params__state
        layer = eqx.combine(params, static)
        x, state = layer(x, state)
        return x, state

    x, all_state = jax.lax.scan(_step, x, (all_params, all_state))
    return x, all_state
```
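For completeness, a sketch of how this is then called (assuming the rest of @lockwo's script above): the state now scans alongside the parameters, so you only vmap the data.

```python
# Sketch: batch only the data; the scan carries one (4,) state slice per layer.
x = jnp.ones((7, 4))
m, state = eqx.nn.make_with_state(OddModel)(key)
y, state = jax.vmap(m, in_axes=(0, None))(x, state)
print(y.shape)  # (7, 4)
```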
On which note, heads-up @lockwo that if you were to take a `jax.grad` of the scan you have, then you'd run into the same XLA bug I mentioned here ('And sadly, XLA has a longstanding bug in which grad-of-loop-of-inplace will make copies of that buffer during the backward pass!'). Specifically, you shouldn't do grad-of-loop-of-inplace if you read from the buffer you're also writing to. (But if the buffer is never read during the loop, and only written to, then it's fine.)
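To illustrate the distinction, a minimal sketch (not from this thread; `f_read` and `f_write` are made-up names):

```python
import jax
import jax.numpy as jnp

def f_read(x):
    # Reads from the same buffer it writes to inside the loop: taking
    # jax.grad of this is the pattern the XLA bug can make copies for.
    def step(buf, i):
        buf = buf.at[i].set(buf[i - 1] * x)  # read *and* write `buf`
        return buf, None
    buf, _ = jax.lax.scan(step, jnp.ones(4), jnp.arange(1, 4))
    return buf.sum()

def f_write(x):
    # Only ever writes to the buffer during the loop: this is fine.
    def step(buf, i):
        buf = buf.at[i].set(x * i)  # write-only
        return buf, None
    buf, _ = jax.lax.scan(step, jnp.zeros(4), jnp.arange(4))
    return buf.sum()

print(jax.grad(f_read)(1.0), jax.grad(f_write)(1.0))
```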
Ah yes, that's good. I forgot you don't need to call `.update`, since the substate wasn't meaningful (and you can therefore make it more elegant by scanning over both).
Thanks for the bug reminder too ;). Hopefully by the time I remember all the bugs in XLA I will also be putting bounties on them.
Awesome stuff, I was hoping you'd see this @patrick-kidger :).
Yeah, I tried passing the params as the `xs` arg to `lax.scan`, but I must have had a hidden bug or something in my long, long notebook. It makes sense that the state is just another pytree there.
Thanks everyone, I learnt a lot here!