equinox
equinox copied to clipboard
Weird `jax` error when applying `vmap` twice while using batchnorm
I'm trying to:
- Use the init/apply design pattern.
- Vmap a module twice (where the module uses batchnorm) in order to handle inputs of shape
`[n_samples, batch_dim, feature_dim]`
instead of just `[batch_dim, feature_dim]`.
import jax
import jax.numpy as jnp
import jax.random as random
from typing import Any, Callable
import equinox as eqx
# Alias for the state container type, used in annotations below.
# Use the public export rather than the private ``eqx.nn._stateful`` module,
# whose path is an implementation detail.
State = eqx.nn.State
def batch_model(model: eqx.Module, axis_name: str = "batch") -> Callable:
    """Vectorise ``model`` over the leading axis of its first argument.

    The returned callable takes ``(inputs, state)``. ``inputs`` is mapped over
    axis 0, while ``state`` is passed to every call unchanged
    (``in_axes=None``). The first output is stacked along axis 0; the second
    output must not depend on the mapped axis (``out_axes=None``).

    Args:
        model: Callable invoked as ``model(x, state)`` and returning a pair
            ``(output, new_state)``.
        axis_name: Name given to the mapped axis. It must match the
            ``axis_name`` of any layer inside ``model`` that reduces over the
            batch. Defaults to ``"batch"``.

    Returns:
        The ``jax.vmap``-transformed callable.
    """
    # see BatchNorm: https://docs.kidger.site/equinox/api/nn/normalisation/
    return jax.vmap(
        model,
        in_axes=(0, None),
        out_axes=(0, None),
        axis_name=axis_name,
    )
def init_apply_eqx_model(
    model: tuple[Any, State]
) -> tuple[Callable, Callable]:
    """Turn a ``(module, state)`` pair into an ``(init, apply)`` function pair.

    The module is split with ``eqx.partition`` into its inexact-array leaves
    (treated as parameters) and everything else (kept in the closure).

    Args:
        model: Pair of the module and its initial state.

    Returns:
        ``(init, apply)`` where ``init()`` returns ``(params, state)`` and
        ``apply(params, state, input)`` rebuilds the module, batches it with
        ``batch_model`` and returns ``(output, updated_state)``.
    """
    module, initial_state = model
    trainable, frozen = eqx.partition(module, eqx.is_inexact_array)

    def init():
        # Hand back the partitioned parameters together with the initial state.
        return trainable, initial_state

    def apply(params, state, input):
        # Re-attach the non-array parts captured at construction time.
        rebuilt = eqx.combine(params, frozen)
        result, new_state = batch_model(rebuilt)(input, state)
        return result, new_state

    return init, apply
@eqx.nn.make_with_state
class Model(eqx.Module):
    """Linear(32 -> 32) -> BatchNorm -> ReLU -> Linear(32 -> 3).

    The class is wrapped in ``eqx.nn.make_with_state``; the value obtained
    from calling ``Model(key)`` is unpacked elsewhere in this file as a
    ``(model, state)`` pair.
    """

    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    norm: eqx.nn.BatchNorm

    def __init__(self, key):
        """Build the layers, giving each linear layer its own PRNG key."""
        first_key, second_key = random.split(key)
        self.linear1 = eqx.nn.Linear(
            in_features=32, out_features=32, key=first_key
        )
        # The axis name has to match the one used by the enclosing vmap.
        self.norm = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear2 = eqx.nn.Linear(
            in_features=32, out_features=3, key=second_key
        )

    def __call__(self, x, state):
        """Run the forward pass and return ``(output, updated_state)``."""
        hidden = self.linear1(x)
        normed, state = self.norm(hidden, state)
        activated = jax.nn.relu(normed)
        return self.linear2(activated), state
# Reproducer: build the model, derive init/apply, then vmap `apply` over an
# extra leading "samples" axis in two different ways.
# The value returned by Model(...) is unpacked as (model, state) inside
# init_apply_eqx_model.
eqx_model = Model(random.key(0))
init, apply = init_apply_eqx_model(eqx_model)
params, state = init()
# vmap model so that we can have input of shape e.g.
# [4, 10, 32]
# THIS WORKS
# out_axes=0 applies to every output, so both the model output and the state
# update are stacked along a new leading axis of size 4.
new_apply = jax.vmap(apply, in_axes=(None, None, 0), out_axes=0)
out, update = new_apply(params, state, jnp.ones((4, 10, 32)))
print(out.shape)
# THIS DOES NOT WORK
# out_axes=(0, None) asserts that the second output (the state update) does
# not vary along the outer mapped axis. The error below reports that it does.
# NOTE(review): presumably the BatchNorm statistics are computed separately
# from each [10, 32] slice of the mapped input, so the update is mapped over
# the outer axis -- TODO confirm.
new_apply = jax.vmap(apply, in_axes=(None, None, 0), out_axes=(0, None))
new_apply(params, state, jnp.ones((4, 10, 32)))
# ERROR: 'ValueError: vmap has mapped output but out_axes is None'