equinox
Recompilation of purely static `Module` bound methods
Consider the following code:
```python
import jax
import equinox as eqx
from functools import partial


class Model(eqx.Module):
    def f(self):
        return


model = Model()

# We get a different function (a new Partial) each time we __get__ a bound method
f1 = model.f
f2 = model.f
assert f1 is not f2
print(hash(f1))
print(hash(f2))


# Use the bound method as a static argument: jax.jit now cares only about
# the __hash__ and __eq__ of that argument
@partial(jax.jit, static_argnums=0)
def apply(f):
    return f()


# This causes a recompilation, since the hashes of f1 and f2 differ
with jax.log_compiles(True):
    apply(f1)
    apply(f2)
```
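The same effect can be reproduced in plain Python, without JAX or Equinox. The sketch below uses a hypothetical `rebinding_method` descriptor as a stand-in for the wrapper: like it, `__get__` builds a new partial on every access. A plain dict keyed on the bound method stands in for `jax.jit`'s cache of static arguments (an assumption for illustration; it is not JAX's actual cache).

```python
import functools


class rebinding_method:
    """Hypothetical stand-in for a method wrapper whose __get__ builds a
    new partial on every attribute access."""

    def __init__(self, method):
        self.method = method

    def __get__(self, instance, owner=None):
        if instance is None:
            return self.method
        # A fresh object every time the attribute is looked up.
        return functools.partial(self.method, instance)


class Model:
    @rebinding_method
    def f(self):
        return None


model = Model()
f1 = model.f
f2 = model.f

# functools.partial defines no __eq__/__hash__, so it falls back to identity.
assert f1 is not f2
assert f1 != f2

# A dict keyed on the "static argument" gets one entry per attribute access,
# which is the analogue of one compilation per access.
compile_cache = {}
for fn in (f1, f2):
    compile_cache.setdefault(fn, "compiled")
print(len(compile_cache))  # 2
```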
tl;dr: bound methods have a different hash each time they are accessed, because of this wrapper: https://github.com/patrick-kidger/equinox/blob/2cf6fecb1fb275a13d4931ec79fa91e4f38cc315/equinox/module.py#L64. Its `__get__` returns a new `Partial` on every access, so `jax.jit` sees a new static argument each time. We should probably memoize the result of `__get__` so that we don't get a new `Partial` every time we access a bound method (assuming the `Module` instance is immutable).
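A minimal sketch of the memoization idea, in plain Python rather than Equinox internals. The `memoized_method` descriptor is hypothetical, and it assumes instances have a `__dict__` (no `__slots__`) and are immutable. It caches the bound partial in the instance's `__dict__`, writing to it directly so that a frozen dataclass's `__setattr__` is bypassed. `functools.partial` stands in for Equinox's `Partial`.

```python
import functools
from dataclasses import dataclass


class memoized_method:
    """Hypothetical descriptor: returns the same partial on every access
    from a given instance, instead of building a new one each time."""

    def __init__(self, method):
        self.method = method

    def __set_name__(self, owner, name):
        # Key under which the bound partial is cached on each instance.
        self.cache_key = f"_bound_{name}"

    def __get__(self, instance, owner=None):
        if instance is None:
            return self.method
        # Write straight into __dict__: frozen dataclasses forbid setattr.
        cache = instance.__dict__
        if self.cache_key not in cache:
            cache[self.cache_key] = functools.partial(self.method, instance)
        return cache[self.cache_key]


@dataclass(frozen=True)
class Model:
    scale: float = 2.0

    @memoized_method  # not annotated, so dataclass does not treat it as a field
    def f(self, x):
        return self.scale * x


model = Model()
f1 = model.f
f2 = model.f
assert f1 is f2              # same object, so same hash and equal
print(f1(3.0))               # 6.0
assert Model().f is not f1   # a different instance still gets its own partial
```

Caching this way only helps repeated accesses on the *same* instance. Any new `Model` (for example, one rebuilt from its pytree leaves) would get a new partial, and therefore a new static argument. The cached partial also holds a reference back to its instance. That cycle is collected by Python's garbage collector rather than by reference counting.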