BatchNorm ValueError with larger ndims.
Hi, thank you for the awesome library; I appreciate all the effort that has gone into it.

I am trying to adapt a PyTorch time-series transformer to Equinox. The transformer uses BatchNorm instead of LayerNorm; however, for tensors with a higher number of dimensions I am running into some issues.
I took this from the Equinox test code:

```python
import equinox as eqx
import jax
import jax.random as jrandom


def getkey():
    return jrandom.PRNGKey(1)


x0 = jrandom.uniform(getkey(), (5,))
x1 = jrandom.uniform(getkey(), (10, 5))
x2 = jrandom.uniform(getkey(), (10, 5, 6))
x3 = jrandom.uniform(getkey(), (10, 5, 7, 8))

bn = eqx.nn.BatchNorm(5, "batch")
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

for ii, x in enumerate((x1, x2, x3)):
    print(f"Running test {ii}")
    out, state = vbn(x, state)
    assert out.shape == x.shape
    running_mean, running_var = state.get(bn.state_index)
    assert running_mean.shape == (5,)
    assert running_var.shape == (5,)
```
This raises a ValueError:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 17
     15 for ii,x in enumerate((x1, x2, x3)):
     16     print(f"Running test {ii}")
---> 17     out, state = vbn(x, state)
     18     assert out.shape == x.shape
     19     running_mean, running_var = state.get(bn.state_index)

    [... skipping hidden 3 frame]

File ~/miniconda3/envs/best-python-environment-3.10/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File ~/miniconda3/envs/best-python-environment-3.10/lib/python3.10/site-packages/equinox/nn/_batch_norm.py:165, in BatchNorm.__call__(self, x, state, key, inference)
    163 running_mean = lax.select(first_time, batch_mean, running_mean)
    164 running_var = lax.select(first_time, batch_var, running_var)
--> 165 state = state.set(self.state_index, (running_mean, running_var))
    167 def _norm(y, m, v, w, b):
    168     out = (y - m) / jnp.sqrt(v + self.eps)

File ~/miniconda3/envs/best-python-environment-3.10/lib/python3.10/site-packages/equinox/nn/_stateful.py:169, in State.set(self, item, value)
    167 value = jtu.tree_map(jnp.asarray, value)
    168 if jax.eval_shape(lambda: old_value) != jax.eval_shape(lambda: value):
--> 169     raise ValueError("Old and new values have different structures.")
    170 state = self._state.copy()  # pyright: ignore
    171 state[item.marker] = value

ValueError: Old and new values have different structures.
```
The versions of the respective libraries in my environment:

```
equinox    0.11.2   pypi_0  pypi
flowjax    10.1.0   pypi_0  pypi
jax        0.4.19   pypi_0  pypi
jaxlib     0.4.14   pypi_0  pypi
jaxtyping  0.2.23   pypi_0  pypi
```
Has the way BatchNorm is used changed? Or is there something specific to my system that is associated with this? Any advice on how to proceed would be greatly appreciated.
Hmm. I can't reproduce this error, I'm afraid: the MWE you've provided runs without errors for me. (In my case I have `equinox==0.11.2`, `jax` and `jaxlib` at 0.4.19, and `jaxtyping==0.2.23`.)
You could try running your script as `ipython --pdb your_file.py` to check the value of each local variable at different parts of the stack, and see where things might have gone wrong.
Thanks for the reply. I figured it out: some configuration in my environment was causing the error, namely the following import and setting:

```python
from jax.config import config

config.update("jax_enable_x64", True)
```
Ah, that'll be it. Looks like you'll need to initialise the batch norm state with the appropriate dtype:

```python
bn = eqx.nn.BatchNorm(5, "batch", dtype=jax.numpy.float64)
```
I can see that the error message isn't super helpful here -- I'll try to improve that.
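For reference, the check that fires here is the `jax.eval_shape` comparison visible in the traceback's `State.set` frame. A minimal sketch (pure JAX, no Equinox needed) of why enabling x64 trips it: the state was initialised in float32, while the batch statistics come out as float64, so the abstract shape/dtype structures no longer match.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # the setting that triggered the bug

old_value = jnp.zeros(5, dtype=jnp.float32)  # state initialised in float32
new_value = jnp.zeros(5, dtype=jnp.float64)  # batch stats computed in float64

# Same comparison as in equinox/nn/_stateful.py: abstract shape/dtype only.
old_struct = jax.eval_shape(lambda: old_value)
new_struct = jax.eval_shape(lambda: new_value)

# Same shape, different dtype -> "Old and new values have different structures."
assert old_struct != new_struct
```

Without `jax_enable_x64`, the float64 request would be silently canonicalised to float32 and the two structures would compare equal, which is why the MWE passes in a default environment.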
Side note: the use of `eqx.nn.State` isn't super memory efficient -- the last few Equinox releases have included `eqx.nn.make_with_state`, e.g.

```python
bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(5, "batch", dtype=jax.numpy.float64)
```

which improves on this.
Cool, thanks a lot, that makes sense. I was trying to figure out the right order of operations (or whether it matters) between using `vmap` and `make_with_state`. I originally thought I had to use multiple vmaps to make BatchNorm work with multiple dimensions, but ran into a catch-22, since you need the constructor arguments on hand to call `make_with_state` and get the state. Using `make_with_state` first and then `vmap` also gave me the above error, so I thought that was wrong.
Anyway, it works now, so on to the next part of converting the PyTorch code. Thanks a bunch, I appreciate it.