equinox icon indicating copy to clipboard operation
equinox copied to clipboard

Weird `jax` error when trying to vmap twice while using batchnorm

Open PaulScemama opened this issue 2 months ago • 5 comments

I'm trying to:

  1. Utilize the init/apply design pattern.
  2. Vmap a module twice (where the module uses batchnorm) in order to handle inputs of shape `[n_samples, batch_dim, feature_dim]` instead of just `[batch_dim, feature_dim]`.
import jax

import jax.numpy as jnp
import jax.random as random

from typing import Any, Callable
import equinox as eqx
# Alias for Equinox's state container, used in the type annotations below.
# Taken from the public `eqx.nn` namespace rather than the private
# `eqx.nn._stateful` module, whose location may change between releases.
State = eqx.nn.State

def batch_model(model: eqx.Module) -> Callable:
    """Vectorise ``model`` over a leading axis named ``"batch"``.

    The returned callable takes two positional arguments: the first is mapped
    over its axis 0 and the second is passed through unmapped. Its first
    output is stacked along axis 0 and its second output is returned unmapped.
    """
    # The axis name is the one the BatchNorm layer is configured with; see
    # https://docs.kidger.site/equinox/api/nn/normalisation/
    vmap_options = {
        "in_axes": (0, None),
        "out_axes": (0, None),
        "axis_name": "batch",
    }
    return jax.vmap(model, **vmap_options)
    

def init_apply_eqx_model(
    model: tuple[Any, State]
) -> tuple[Callable, Callable]:
    """Turn a ``(module, state)`` pair into functional ``init``/``apply`` callables.

    The module is split once into its inexact-array leaves (``params``) and
    everything else (``static``); ``static`` is closed over by ``apply``.

    Returns:
        ``(init, apply)`` where ``init()`` gives back ``(params, state)`` and
        ``apply(params, state, input)`` gives back ``(out, updates)``.
    """
    module, initial_state = model
    params, static = eqx.partition(module, eqx.is_inexact_array)

    def init():
        """Return the partitioned parameters together with the initial state."""
        return params, initial_state

    def apply(params, state, input):
        """Rebuild the module from ``params``, batch it, and run it on ``input``."""
        rebuilt = eqx.combine(params, static)
        out, updates = batch_model(rebuilt)(input, state)
        return out, updates

    return init, apply


@eqx.nn.make_with_state
class Model(eqx.Module):
    """Linear(32 -> 32) -> BatchNorm -> ReLU -> Linear(32 -> 3).

    ``__call__`` threads a ``state`` argument through the BatchNorm layer and
    returns the updated state alongside the output.
    """

    # Field order is left as-is: it fixes the order of the module's leaves.
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    norm: eqx.nn.BatchNorm

    def __init__(self, key):
        """Build the layers, giving each linear layer its own PRNG sub-key."""
        first_key, second_key = random.split(key)
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=first_key)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=second_key)
        # "batch" is the axis name that batch_model passes to jax.vmap.
        self.norm = eqx.nn.BatchNorm(input_size=32, axis_name="batch")

    def __call__(self, x, state):
        """Run the forward pass; return ``(output, updated_state)``."""
        hidden = self.linear1(x)
        normed, new_state = self.norm(hidden, state)
        output = self.linear2(jax.nn.relu(normed))
        return output, new_state
    
# Reproducer: build the model, derive the functional init/apply pair, then
# vmap `apply` a second time over an extra leading "samples" axis.
# `Model(...)` is unpacked as a (module, state) pair by init_apply_eqx_model.
eqx_model = Model(random.key(0))
init, apply = init_apply_eqx_model(eqx_model)
params, state = init()

# vmap model so that we can have input of shape e.g.
# [4, 10, 32]

# THIS WORKS
# params and state are shared (in_axes=None); only the input is mapped.
# out_axes=0 applies to every output, so the state update is stacked along
# the new leading axis as well.
new_apply = jax.vmap(apply, in_axes=(None, None, 0), out_axes=0)
out, update = new_apply(params, state, jnp.ones((4, 10, 32)))
print(out.shape)


# THIS DOES NOT WORK
# Same call, but the second output (the state update) is requested unmapped.
# NOTE(review): the error below suggests the state update varies along the
# outer mapped axis (presumably because BatchNorm's statistics are computed
# from the mapped input), so it cannot be returned with out_axes=None —
# confirm against the JAX/Equinox docs.
new_apply = jax.vmap(apply, in_axes=(None, None, 0), out_axes=(0, None))
new_apply(params, state, jnp.ones((4, 10, 32)))
# ERROR: 'ValueError: vmap has mapped output but out_axes is None'

PaulScemama avatar Apr 17 '24 01:04 PaulScemama