Weird `jax` error when trying vmap twice while using batchnorm

Open PaulScemama opened this issue 2 years ago • 5 comments

I'm trying to:

  1. Utilize the init/apply design pattern.
  2. Vmap a module twice (where the module uses batchnorm) in order to handle inputs of shape [n_samples, batch_dim, feature_dim] instead of just [batch_dim, feature_dim].

from typing import Any, Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as random

# Public name for the state type; avoids reaching into the private _stateful module.
State = eqx.nn.State

def batch_model(model: eqx.Module) -> Callable:
    # see BatchNorm: https://docs.kidger.site/equinox/api/nn/normalisation/
    return jax.vmap(model, in_axes=(0, None), out_axes=(0, None), axis_name="batch")
    

def init_apply_eqx_model(
    model: tuple[Any, State]
) -> tuple[Callable, Callable]:
    model, state = model
    params, static = eqx.partition(model, eqx.is_inexact_array)

    def init():
        return params, state

    def apply(params, state, input):
        model = eqx.combine(params, static)
        batched_model = batch_model(model)
        out, updates = batched_model(input, state)
        return out, updates

    return init, apply


@eqx.nn.make_with_state
class Model(eqx.Module):
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear
    norm: eqx.nn.BatchNorm

    def __init__(self, key):
        key1, key2 = random.split(key)
        self.linear1 = eqx.nn.Linear(in_features=32, out_features=32, key=key1)
        self.norm = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key2)

    def __call__(self, x, state):
        x = self.linear1(x)
        x, state = self.norm(x, state)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state
    
eqx_model = Model(random.key(0))
init, apply = init_apply_eqx_model(eqx_model)
params, state = init()

# vmap model so that we can have input of shape e.g.
# [4, 10, 32]

# THIS WORKS
new_apply = jax.vmap(apply, in_axes=(None, None, 0), out_axes=0)
out, update = new_apply(params, state, jnp.ones((4, 10, 32)))
print(out.shape)


# THIS DOES NOT WORK
new_apply = jax.vmap(apply, in_axes=(None, None, 0), out_axes=(0, None))
new_apply(params, state, jnp.ones((4, 10, 32)))
# ERROR: 'ValueError: vmap has mapped output but out_axes is None'

PaulScemama avatar Apr 17 '24 01:04 PaulScemama

I think this is expected. The "inner" vmap is the one being addressed by the BatchNorm (because they have matching axis_names), so the "outer" vmap is over the whole operation, state-of-batch-norm included. Thus, you should get a batch-of-states as output.
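
To illustrate the point above without Equinox, here is a minimal plain-JAX sketch. The `normalise` function is a hypothetical stand-in for a BatchNorm-like layer (it is not Equinox's implementation): its statistic is reduced over the axis named "batch", so the inner vmap can return it unbatched, while the outer vmap still produces one statistic per sample.

```python
import jax
import jax.numpy as jnp


def normalise(x):
    # Hypothetical stand-in for a layer with batch statistics: the mean is
    # reduced over the vmap axis named "batch".
    mean = jax.lax.pmean(x, axis_name="batch")
    return x - mean, mean


# Inner vmap: its axis_name matches the pmean, so `mean` comes out unbatched
# and out_axes=None is accepted for it.
batched = jax.vmap(normalise, out_axes=(0, None), axis_name="batch")

x = jnp.arange(12.0).reshape(4, 3)  # [n_samples, batch_dim]

# Outer vmap: nothing reduces over this axis, so each sample keeps its own mean.
out, means = jax.vmap(batched, out_axes=(0, 0))(x)
print(out.shape, means.shape)

# Asking the outer vmap for an unbatched mean is rejected.
try:
    jax.vmap(batched, out_axes=(0, None))(x)
    error = None
except ValueError as exc:
    error = str(exc)
print(error)
```

The second call mirrors the failing `out_axes=(0, None)` case from the original post, with the mean playing the role of the BatchNorm state.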

If you're looking to have BatchNorm operate over multiple vmaps then this is doable like so:

BatchNorm(..., axis_name=("batch", "batch2"))

...

jax.vmap(apply, ..., axis_name="batch2")

Passing a tuple of axis names to BatchNorm tells it to operate across both of those transformations.

patrick-kidger avatar Apr 19 '24 22:04 patrick-kidger

@patrick-kidger thanks for the response! That makes sense.

Just fiddling around with things and I noticed the following yields an error:

from dataclasses import replace

l = eqx.nn.BatchNorm(14, axis_name="blah")
new_l = replace(l, axis_name="foo")
# TypeError: __init__() got an unexpected keyword argument 'weight'

But this is not specific to BatchNorm:

l = eqx.nn.Linear(5, 3, key=random.key(1))
new_l = replace(l, in_features=5)
# TypeError: __init__() got an unexpected keyword argument 'weight'

It seems to have more to do with the general _ModuleMeta class. I took a quick peek at _module.py but couldn't immediately find any comments on how to create a new copy of a _ModuleMeta class while replacing a single attribute (or several).

Is this possible to do?

Many thanks!

PaulScemama avatar Apr 23 '24 23:04 PaulScemama

I believe equinox.tree_at can replace the attributes.

nasyxx avatar Apr 24 '24 04:04 nasyxx

Yup, this is actually a known issue with dataclasses.replace; it doesn't respect custom __init__s.
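
The failure can be reproduced with the standard library alone. In this sketch `Layer` is a hypothetical class (not Equinox code) whose hand-written `__init__` takes fewer arguments than the class has fields, roughly the situation described above.

```python
from dataclasses import dataclass, replace


@dataclass
class Layer:
    weight: float
    size: int

    # Hand-written __init__: it accepts `size` only and derives `weight` itself.
    # @dataclass keeps this method instead of generating its own.
    def __init__(self, size: int):
        self.size = size
        self.weight = 0.0


layer = Layer(size=3)

# replace() collects every field (weight included) and passes them all to
# __init__ as keyword arguments, which the custom __init__ does not accept.
try:
    replace(layer, size=5)
    error = None
except TypeError as exc:
    error = str(exc)
print(error)
```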

Indeed use eqx.tree_at instead!

patrick-kidger avatar Apr 24 '24 06:04 patrick-kidger

Great. Thank you both!

PaulScemama avatar Apr 25 '24 12:04 PaulScemama