
Batchnorm example doesn't work

Open joaomarcoscsilva opened this issue 1 year ago • 3 comments

Hi there!

I'm sorry if this issue is already known, but the BatchNorm layer doesn't seem to be working correctly. For instance, the example code from the documentation fails with `ValueError: Unable to parse module assembly (see diagnostics)`:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])

x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
# ValueError: Unable to parse module assembly (see diagnostics)
```

joaomarcoscsilva, Sep 22 '22 00:09

Hmm. This works fine on my machine (CPU; Equinox 0.7.1, JAX 0.3.17, jaxlib 0.3.15).

The error itself isn't familiar to me either. This looks like some deeper error somewhere in your system.

patrick-kidger, Sep 22 '22 00:09

@patrick-kidger I ran some more tests, and it seems this error only occurs when running on a GPU backend. This colab reproduces it when the GPU is enabled, but there's no error when the CPU or TPU is used instead.
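To check which backend a computation is actually running on, and to rerun it on the CPU for comparison, something like the following sketch can be used. It assumes a JAX version recent enough to provide the `jax.default_device` context manager, and `step` is a hypothetical pure-JAX stand-in (a `vmap` over a named axis with a `pmean` collective), not the Equinox repro itself.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical stand-in for the failing computation: a vmap over a named
    # axis with a cross-example collective, similar to what BatchNorm needs.
    return jax.vmap(lambda v: v - jax.lax.pmean(v, axis_name="batch"), axis_name="batch")(x)


print("default backend:", jax.default_backend())  # e.g. "cpu", "gpu" or "tpu"
print("devices:", jax.devices())

x = jnp.arange(12.0).reshape(4, 3)
out_default = step(x)

# Rerun the same computation with the CPU as the default device for comparison.
with jax.default_device(jax.devices("cpu")[0]):
    out_cpu = step(x)
```

If `out_default` errors or differs only when the default backend is a GPU, that points at the GPU compilation path rather than at the model code.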

joaomarcoscsilva, Sep 22 '22 00:09

Oh, interesting. Yup, I can reproduce this on my GPU machine too.

The bad news is that I don't think this is something that can be fixed quickly. This looks like an obscure error from JAX doing something weird.

The good news is that core JAX has some incoming changes that might see `equinox.experimental.BatchNorm` switch to a totally different, and robust, implementation.

(For context about why BatchNorm is such a tricky operation to support: it's because getting/setting its running statistics is a side effect. Meanwhile JAX, being functional, basically doesn't support side effects! Supporting this properly, without compromising on all the lovely things that make JAX efficient, is pretty difficult.)
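A minimal sketch of the usual functional workaround: instead of mutating its running statistics, the layer takes them as an explicit input and returns updated statistics as an ordinary output. This is plain JAX, not Equinox's API; `batch_norm`, `apply`, the state dictionary layout, and the momentum convention (`new = momentum * old + (1 - momentum) * batch`) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def batch_norm(x, state, momentum=0.99, eps=1e-5):
    """Normalise one example `x` (shape: (channels,)) using statistics taken
    across the vmapped axis named "batch". Returns (output, new_state) instead
    of mutating `state`."""
    mean = jax.lax.pmean(x, axis_name="batch")
    var = jax.lax.pmean((x - mean) ** 2, axis_name="batch")
    y = (x - mean) / jnp.sqrt(var + eps)
    # The "side effect" becomes an ordinary return value.
    new_state = {
        "mean": momentum * state["mean"] + (1 - momentum) * mean,
        "var": momentum * state["var"] + (1 - momentum) * var,
    }
    return y, new_state


def apply(xs, state):
    # Map over examples; the running statistics are shared, not batched.
    ys, new_states = jax.vmap(batch_norm, in_axes=(0, None), axis_name="batch")(xs, state)
    # Every example sees the same batch statistics, so all rows agree; keep one.
    new_state = jax.tree_util.tree_map(lambda s: s[0], new_states)
    return ys, new_state


xs = jax.random.normal(jax.random.PRNGKey(0), (10, 4))
state = {"mean": jnp.zeros(4), "var": jnp.ones(4)}
ys, new_state = jax.jit(apply)(xs, state)  # caller keeps new_state for the next step
```

Because the state is just another input and output, the whole update stays pure and can be wrapped in `jax.jit`; the cost is that every caller has to thread `state` through by hand.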

So right now I think the answer is a somewhat unsatisfying "let's wait and see".

patrick-kidger, Sep 22 '22 21:09