equinox
Batchnorm example doesn't work
Hi there!
I'm sorry if this issue is already known, but apparently the BatchNorm layer isn't working correctly. For instance, the example code listed in the documentation fails with ValueError: Unable to parse module assembly (see diagnostics):
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
key = jr.PRNGKey(0)
mkey, dkey = jr.split(key)
model = eqx.nn.Sequential([
    eqx.nn.Linear(in_features=3, out_features=4, key=mkey),
    eqx.experimental.BatchNorm(input_size=4, axis_name="batch"),
])
x = jr.normal(dkey, (10, 3))
jax.vmap(model, axis_name="batch")(x)
# ValueError: Unable to parse module assembly (see diagnostics)
Hmm. This works fine on my machine. (CPU, Equinox version 0.7.1; JAX version 0.3.17, jaxlib version 0.3.15)
The error itself isn't familiar to me either. This looks like some deeper error somewhere in your system.
@patrick-kidger I ran some more tests, and it seems that this error only occurs when run on a GPU backend. This colab can reproduce it when the GPU is enabled, but there's no error when either the CPU or TPU is used instead.
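For anyone trying to reproduce a backend-specific failure like this, it helps to confirm which backend JAX is actually using. A minimal sketch, using only core JAX calls (the variable names `backend` and `cpu_devices` are illustrative):

```python
import jax

# Platform JAX dispatches to by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print("default backend:", backend)

# Devices visible on the default backend.
for device in jax.devices():
    print(device.platform, device.id)

# The CPU backend is always available, so it can serve as a reference
# when checking whether a failure is specific to an accelerator.
cpu_devices = jax.devices("cpu")
print("CPU devices:", cpu_devices)
```

Running the failing example once per backend (for instance by switching the Colab runtime type) then shows whether the error follows the backend.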
Oh, interesting. Yup, I can reproduce this on my GPU machine too.
The bad news is that I don't think this is something that can be fixed quickly. This looks like an obscure error from JAX doing something weird.
The good news is that core JAX has some incoming changes that might see equinox.experimental.BatchNorm switching to a totally different, and robust, implementation.
(For context about why BatchNorm is such a tricky operation to support: it's because getting/setting its running statistics is a side effect. Meanwhile JAX, being functional, basically doesn't support side effects! Supporting this properly, without compromising on all the lovely things that make JAX efficient, is pretty difficult.)
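To illustrate the point about side effects, here is a minimal sketch of batch normalisation written in a purely functional style in plain JAX. It is not how equinox.experimental.BatchNorm is implemented; the function `batch_norm` and its `momentum` and `eps` parameters are hypothetical names chosen for illustration. The running statistics are passed in and returned explicitly instead of being updated in place:

```python
import jax
import jax.numpy as jnp


def batch_norm(x, running_mean, running_var, momentum=0.99, eps=1e-5):
    """Normalise a batch `x` of shape (batch, features).

    Returns the normalised batch together with the *new* running
    statistics, rather than mutating anything in place.
    """
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    y = (x - mean) / jnp.sqrt(var + eps)
    # Exponential moving average of the batch statistics.
    new_mean = momentum * running_mean + (1 - momentum) * mean
    new_var = momentum * running_var + (1 - momentum) * var
    return y, new_mean, new_var


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (10, 4))
running_mean = jnp.zeros(4)
running_var = jnp.ones(4)

# Because the function is pure, it can be jitted like any other JAX function.
step = jax.jit(batch_norm)
y, running_mean, running_var = step(x, running_mean, running_var)
```

The cost of this style is that every caller has to thread the statistics through by hand, which is exactly the bookkeeping a layer with hidden state tries to avoid.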
So right now I think the answer is a somewhat unsatisfying "let's wait and see".
This seems to work now, no? At least on the colab. Has anything changed since?
Seems to work with equinox==0.9.2. One small thing is that without jitting it seems to be very slow:
import equinox as eqx
import jax.random as jrandom
import jax.nn as jnn
import jax.numpy as jnp
import jax
import time
class Network(eqx.Module):
    net: eqx.Module

    def __init__(self, in_size, out_size, width, depth, *, key, bn=True):
        keys = jrandom.split(key, depth + 1)
        layers = []
        if depth == 0:
            layers.append(eqx.nn.Linear(in_size, out_size, key=keys[0]))
        else:
            layers.append(eqx.nn.Linear(in_size, width, key=keys[0]))
            if bn:
                layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
            for i in range(depth - 1):
                layers.append(eqx.nn.Linear(width, width, key=keys[i + 1]))
                if bn:
                    layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
                layers.append(eqx.nn.Lambda(jnn.relu))
            layers.append(eqx.nn.Linear(width, out_size, key=keys[-1]))
        self.net = eqx.nn.Sequential(layers)

    def __call__(self, x):
        return self.net(x)
if __name__ == "__main__":
    key = jrandom.PRNGKey(0)
    init_key, data_key = jrandom.split(key, 2)
    net = Network(10, 5, 3, 300, key=init_key, bn=False)
    bn_net = Network(10, 5, 3, 300, key=init_key, bn=True)
    x = jrandom.normal(data_key, (32, 10))
    func = jax.vmap(net, axis_name="batch")
    jitted = jax.jit(func)
    bn_func = jax.vmap(bn_net, axis_name="batch")
    bn_jitted = jax.jit(bn_func)  # jit the BatchNorm variant
    # compile
    jitted(x)
    bn_jitted(x)
    start = time.time()
    y = func(x)
    finish = time.time()
    print(f"Wout BN / Wout JIT took: {finish-start:.2f}")
    start = time.time()
    y = jitted(x)
    finish = time.time()
    print(f"Wout BN / With JIT took: {finish-start:.2f}")
    start = time.time()
    y = bn_func(x)
    finish = time.time()
    print(f"With BN / Wout JIT took: {finish-start:.2f}")
    start = time.time()
    y = bn_jitted(x)
    finish = time.time()
    print(f"With BN / With JIT took: {finish-start:.2f}")
On my GPU machine this outputs:
Wout BN / Wout JIT took: 0.42
Wout BN / With JIT took: 0.00
With BN / Wout JIT took: 15.67
With BN / With JIT took: 0.00
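One caveat when reading timings like these: JAX dispatches work asynchronously, so a call can return before the device has finished computing. A minimal sketch of a timing helper that waits for the result, assuming a JAX version recent enough to provide `jax.block_until_ready` (the helper `timed` and the toy function `f` are hypothetical names, not part of the script above):

```python
import time

import jax
import jax.numpy as jnp


def timed(fn, *args):
    """Run fn(*args) once and return (result, seconds), waiting for the device."""
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


@jax.jit
def f(x):
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))
# The first call includes tracing and compilation.
_, first_call = timed(f, x)
# Later calls with the same shapes reuse the compiled function.
_, second_call = timed(f, x)
print(f"first call: {first_call:.4f}s, second call: {second_call:.4f}s")
```

The same helper could wrap `func`, `jitted`, `bn_func` and `bn_jitted` from the script above to separate dispatch time from actual compute time.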
Closing as the experimental BatchNorm has now been de-experimental'ised into eqx.nn.BatchNorm, which should fix this.