
BatchNorm ValueError with larger ndims.

Open zeneofa opened this issue 1 year ago • 4 comments

Hi, thank you for the awesome library; I appreciate all the effort that has been put into it.

I am trying to adapt a PyTorch time-series transformer to Equinox; the transformer uses BatchNorm instead of LayerNorm. However, for tensors with a higher number of dimensions I am running into some issues.

I adapted the following from the Equinox test code:

import equinox as eqx
import jax.random as jrandom
import jax
def getkey():
    return jrandom.PRNGKey(1)

x0 = jrandom.uniform(getkey(), (5,))
x1 = jrandom.uniform(getkey(), (10, 5))
x2 = jrandom.uniform(getkey(), (10, 5, 6))
x3 = jrandom.uniform(getkey(), (10, 5, 7, 8))
bn = eqx.nn.BatchNorm(5, "batch")
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

for ii, x in enumerate((x1, x2, x3)):
    print(f"Running test {ii}")
    out, state = vbn(x, state)
    assert out.shape == x.shape
    running_mean, running_var = state.get(bn.state_index)
    assert running_mean.shape == (5,)
    assert running_var.shape == (5,)

This generates a ValueError:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[8], line 17
     15 for ii,x in enumerate((x1, x2, x3)):
     16     print(f"Running test {ii}")
---> 17     out, state = vbn(x, state)
     18     assert out.shape == x.shape
     19     running_mean, running_var = state.get(bn.state_index)

    [... skipping hidden 3 frame]

File ~/miniconda3/envs/best-python-environment-3.10/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File ~/miniconda3/envs/best-python-environment-3.10/lib/python3.10/site-packages/equinox/nn/_batch_norm.py:165, in BatchNorm.__call__(self, x, state, key, inference)
    163     running_mean = lax.select(first_time, batch_mean, running_mean)
    164     running_var = lax.select(first_time, batch_var, running_var)
--> 165     state = state.set(self.state_index, (running_mean, running_var))
    167 def _norm(y, m, v, w, b):
    168     out = (y - m) / jnp.sqrt(v + self.eps)

File ~/miniconda3/envs/best-python-environment-3.10/lib/python3.10/site-packages/equinox/nn/_stateful.py:169, in State.set(self, item, value)
    167 value = jtu.tree_map(jnp.asarray, value)
    168 if jax.eval_shape(lambda: old_value) != jax.eval_shape(lambda: value):
--> 169     raise ValueError("Old and new values have different structures.")
    170 state = self._state.copy()  # pyright: ignore
    171 state[item.marker] = value

ValueError: Old and new values have different structures.

The versions of the respective libraries in my environment:

equinox                   0.11.2                   pypi_0    pypi
flowjax                   10.1.0                   pypi_0    pypi
jax                       0.4.19                   pypi_0    pypi
jaxlib                    0.4.14                   pypi_0    pypi
jaxtyping                 0.2.23                   pypi_0    pypi

Has the way BatchNorm is used changed? Or is there something specific to my system that is associated with this? Any advice on how to proceed would be greatly appreciated.

zeneofa avatar Nov 15 '23 22:11 zeneofa

Hmm. I can't reproduce this error I'm afraid: the MWE you've provided runs without errors for me. (In my case I have equinox==0.11.2, jax,jaxlib==0.4.19, jaxtyping==0.2.23.)

You could try running your script as `ipython your_file.py --pdb` to check the value of each local variable at different parts of the stack, and see where things might have gone wrong.

patrick-kidger avatar Nov 15 '23 22:11 patrick-kidger

Thanks for the reply. I figured it out: some configuration in my environment caused the error.

Specifically, the following import and setting:

from jax.config import config
config.update("jax_enable_x64", True)
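For context, enabling x64 changes the default floating dtype that JAX uses when creating arrays, so the inputs (and hence the batch statistics BatchNorm computes from them) become float64 while the previously stored state can remain float32 -- which is the structure mismatch in the traceback. A minimal sketch of that effect (plain JAX, not specific to Equinox):

```python
# Minimal sketch: with jax_enable_x64 on, array-creating ops default to
# 64-bit dtypes, so values derived from the inputs end up float64 even
# though other state may have been initialised as float32.
import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jrandom

x = jrandom.uniform(jrandom.PRNGKey(0), (5,))
print(x.dtype)  # float64 once x64 mode is enabled
```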

zeneofa avatar Nov 15 '23 22:11 zeneofa

Ah, that'll be it. Looks like you'll need to initialise the batch norm state with the appropriate dtype:

bn = eqx.nn.BatchNorm(5, "batch", dtype=jax.numpy.float64)

I can see that the error message isn't super helpful here -- I'll try to improve that.

Side note: the use of eqx.nn.State isn't super memory-efficient -- the last few Equinox releases have included eqx.nn.make_with_state, e.g. bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(5, "batch", dtype=jax.numpy.float64), which improves on this.

patrick-kidger avatar Nov 15 '23 22:11 patrick-kidger

Cool, thanks a lot -- that makes sense. I was trying to figure out the right order of operations (or whether it matters) between vmap and make_with_state. I originally thought I had to use multiple vmaps to make BatchNorm work with multiple dimensions, but then ran into a catch-22: you need the constructor arguments in order to call make_with_state and get the state. Using make_with_state first and then vmap also gave me the above error, so I thought that approach was wrong.

Anyway, it works now, so on to the next part of converting the PyTorch code. Thanks a bunch, I appreciate it.

zeneofa avatar Nov 15 '23 23:11 zeneofa