
Combining Modules, vmapping, and lax.switch (jit versus filter_jit)

Open djbower opened this issue 8 months ago • 4 comments

I have a design pattern, shown below, that works when I use @jit on the class method func but fails with a TypeError when I use @eqx.filter_jit. Perhaps naively, I was expecting to be able to swap every @jit decorator for @eqx.filter_jit as I move my code base completely over to Equinox, but evidently there is still something I don't fully understand. I'm not sure whether my design pattern is sub-optimal for use with Equinox, whether I need to add additional Equinox decorators, or whether there is good reason to simply stick with the @jit decorator, which works.

The Module documentation (https://docs.kidger.site/equinox/api/module/module/) interestingly also has a @jit decorator, although there isn't much explanation to clarify why this and not @eqx.filter_jit.

from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import Array, jit, lax
from jaxtyping import ArrayLike

jax.config.update("jax_enable_x64", True)


class SomeClass1(eqx.Module):
    some_array: ArrayLike

    # Code breaks (TypeError: unhashable type: 'DynamicJaxprTracer') when using:
    # @eqx.filter_jit
    @jit
    def func(self, x) -> Array:
        del x
        return jnp.log(self.some_array)


class SomeClass2(eqx.Module):
    some_array: ArrayLike

    # Code breaks (TypeError: unhashable type: 'DynamicJaxprTracer') when using:
    # @eqx.filter_jit
    @jit
    def func(self, x) -> Array:
        return jnp.asarray(self.some_array) * x


class Container(eqx.Module):
    data: tuple

    @eqx.filter_jit
    def func(self, x) -> Array:
        functions = tuple([module.func for module in self.data])

        def apply_func(index: ArrayLike, x) -> Array:
            return lax.switch(index, functions, x)

        vmap_apply_func: Callable = eqx.filter_vmap(apply_func, in_axes=(0, None))
        indices = jnp.arange(len(self.data))
        values = vmap_apply_func(indices, x)

        return values


x = jnp.array([1, 2, 3], dtype=jnp.float_)
a = jnp.array([1, 2, 3], dtype=jnp.float_)
b = jnp.array([4, 5, 6], dtype=jnp.float_)

modules = (SomeClass1(a), SomeClass2(b))
container = Container(modules)
out = container.func(x)
print("out = ", out)

djbower avatar Apr 25 '25 19:04 djbower

There are two main things. First, this happens because filter_jit actually transforms the method into a Partial module that holds the module instance (and therefore its member arrays) as one of its arguments. So it goes from

<bound method SomeClass1.func of SomeClass1(some_array=f64[3])>

to

Partial(
  func=_JitWrapper(
    fn='SomeClass1.func',
    filter_warning=False,
    donate_first=False,
    donate_rest=False
  ),
  args=(SomeClass1(some_array=f64[3]),),
  keywords={}
)

with filter_jit. Modules are hashed based on their leaves, but here one of the leaves is a JAX array, which isn't hashable. Since lax.switch identifies functions by their hashes, this throws an error with filter_jit (but not with jit).

The second thing is that vmapping over a switch when the index itself is batched results in all branches actually being computed, so that's just something to be aware of (and careful about) in some cases.
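To illustrate the second point, here is a minimal sketch in plain JAX. The names branch_log, branch_double and apply are made up for this example and are not from the original code. It vmaps a switch over a batched index and prints the jaxpr of both the unbatched and the batched version so the two lowerings can be compared. As far as I know the unbatched one keeps a conditional while the batched one evaluates every branch and selects between the results, but the exact printed form depends on the JAX version.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical branches, standing in for the module methods in the issue.
def branch_log(x):
    return jnp.log(x)


def branch_double(x):
    return 2.0 * x


def apply(index, x):
    # Dispatch on a (possibly traced) integer index.
    return lax.switch(index, (branch_log, branch_double), x)


x = jnp.array([1.0, 2.0, 3.0])
indices = jnp.arange(2)

# Batched index: one output row per branch.
values = jax.vmap(apply, in_axes=(0, None))(indices, x)
print(values)

# Compare how the unbatched and batched versions are traced.
print(jax.make_jaxpr(apply)(0, x))
print(jax.make_jaxpr(jax.vmap(apply, in_axes=(0, None)))(indices, x))
```

If the branches are expensive and only one is needed per element, this is the case to watch out for.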

As for the design pattern, I'm not sure there are right or wrong answers (at least as long as it achieves the goal of the software). But since you basically want functions that are hashable, I would define a __hash__ (like below, although maybe a more informative one for your situation) and just pass the modules themselves, if that makes sense conceptually for these modules.

from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
from jax import Array, lax
from jaxtyping import ArrayLike

jax.config.update("jax_enable_x64", True)


class SomeClass1(eqx.Module):
    some_array: ArrayLike

    def __call__(self, x) -> Array:
        del x
        return jnp.log(self.some_array)

    def __hash__(self):
        return 0


class SomeClass2(eqx.Module):
    some_array: ArrayLike

    def __call__(self, x) -> Array:
        return jnp.asarray(self.some_array) * x

    def __hash__(self):
        return 1



class Container(eqx.Module):
    data: tuple

    @eqx.filter_jit
    def func(self, x) -> Array:
        functions = self.data

        def apply_func(index: ArrayLike, x) -> Array:
            return lax.switch(index, functions, x)

        vmap_apply_func: Callable = eqx.filter_vmap(apply_func, in_axes=(0, None))
        indices = jnp.arange(len(self.data))
        values = vmap_apply_func(indices, x)

        return values


x = jnp.array([1, 2, 3], dtype=jnp.float_)
a = jnp.array([1, 2, 3], dtype=jnp.float_)
b = jnp.array([4, 5, 6], dtype=jnp.float_)

modules = (SomeClass1(a), SomeClass2(b))
container = Container(modules)
out = container.func(x)
print("out = ", out)

lockwo avatar Apr 25 '25 20:04 lockwo

Thanks for the question! So this is a somewhat subtle point, and arguably a minor (easy to work around) bug in JAX itself.

First of all, here's a smaller repro of the example you're demonstrating:

import equinox as eqx
import jax

class M(eqx.Module):
    x: jax.Array

    def f(self, y):
        return self.x + y

@jax.jit
def g(x):
    fx = M(x).f
    return jax.lax.switch(0, [fx, fx], x)

x = jax.numpy.array([1, 2])
g(x)

and now here's the solution:

import equinox as eqx
import jax

class M(eqx.Module):
    x: jax.Array

    def f(self, y):
        return self.x + y

@jax.jit
def g(x):
    fx = lambda y: M(x).f(y)  # THIS LINE CHANGED
    return jax.lax.switch(0, [fx, fx], x)

x = jax.numpy.array([1, 2])
g(x)

What's happening here is that jax.lax.switch attempts to hash the functions that it is passed. In this case, these are the bound methods SomeClass1(a).func and SomeClass2(b).func. These are pytrees containing the arrays a and b, which aren't hashable, and so an error is thrown.

  • Bound methods of eqx.Modules are pytrees, in which the attributes of the module are nodes.
  • If you use it, then the wrapped function returned by eqx.filter_jit is also a pytree, in which the callable it wraps is a node.

Wrapping these into a lambda -- which is hashable -- makes things work again.

When you add a jax.jit decorator then you're doing something essentially identical to wrapping in a lambda: it's just acting as a hashable function wrapper. (Indeed the root cause here has actually little to do with jax.jit vs eqx.filter_jit, and everything to do with hashable function wrappers.)

I hope that helps!


Some extra notes.

  1. Equinox makes damn near everything a pytree, so as to ensure correctness around operations like jax.grad etc: we want to be certain that all of the arrays we pass are made visible to these transformations, and not 'smuggled in'. This kind of smuggling is what happens when you have jax.jit, which simply returns a regular function.

    Most of the time this distinction doesn't matter, as here, but with sufficient inventiveness it's possible to shoot yourself in the foot due to a lack of pytree-ness.

  2. As a general JAX comment (whether using Equinox or not), then you actually shouldn't need these jax.jit decorators at all. See point 1 in this blog post: you only need to JIT-compile your topmost function, everything else is ignored. Sprinkling JITs around anywhere else is mostly decorative. Also see this example for reference, in which we just jit our make_step function.
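A small plain-JAX sketch of point 2 (the function names here are hypothetical): jitting only the outermost function gives the same values as also jitting the inner one, so the inner decorator buys nothing in terms of results.

```python
import jax
import jax.numpy as jnp


def inner(x):
    return jnp.sin(x) ** 2


# Same function, but with its own (redundant) jit wrapper.
inner_jitted = jax.jit(inner)


@jax.jit
def outer_plain(x):
    # Only the topmost function is jitted.
    return inner(x) + 1.0


@jax.jit
def outer_nested(x):
    # Inner jit nested inside the outer jit.
    return inner_jitted(x) + 1.0


x = jnp.linspace(0.0, 1.0, 4)
out_plain = outer_plain(x)
out_nested = outer_nested(x)
print(out_plain, out_nested)
```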

(EDIT: race condition against @lockwo!)

patrick-kidger avatar Apr 25 '25 20:04 patrick-kidger

Just building on your point 1: for complete generality and JAX compatibility, is it better to use Partial rather than the lambda expression?

from jax.tree_util import Partial

fx = Partial(M(x).f)

djbower avatar Apr 26 '25 09:04 djbower

At least here, the choice between a lambda wrapper and a jax.tree_util.Partial wrapper doesn't matter at all.

If you're after a more general rule to avoid footguns, then:

  • For places where you need a hashable wrapper (basically the functions passed to higher order primitives: jax.lax.{switch, scan, while_loop} etc) then a lambda is fine.
  • For every other use-case then I'd recommend eqx.Partial instead. (The difference is that jax.tree_util.Partial(fn, *args, **kwargs) has only its args and kwargs as nodes but 'closes over' fn. Meanwhile eqx.Partial(fn, *args, **kwargs) has all of fn, *args, **kwargs as nodes. I do think it's a little unfortunate that we have all of functools.partial, jax.tree_util.Partial and equinox.Partial -- whilst I don't have a better solution, I don't love how using JAX makes this kind of detail a confusing experience.)

patrick-kidger avatar Apr 26 '25 09:04 patrick-kidger