
Recompilation of purely static `Module` bound methods

Open · KeAWang opened this issue · 3 comments

Consider the following code:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from functools import partial

class Model(eqx.Module):
    def f(self):
        return

model = Model()
# Each access goes through Module's __get__ wrapper and returns a new Partial object
f1 = model.f
f2 = model.f
assert f1 is not f2
print(hash(f1))
print(hash(f2))

# Use the bound method as a static argument: jax.jit then only looks at its __hash__ and __eq__
@partial(jax.jit, static_argnums=0)
def apply(f):
    return f()

# This triggers a recompilation, since f1 and f2 have different hashes
with jax.log_compiles(True):
    apply(f1)
    apply(f2)
```
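For comparison, plain Python bound methods do not have this problem. Each access of `obj.f` also creates a new method object, but method objects compare equal and hash the same when they wrap the same function and the same instance. A `functools.partial`, by contrast, defines no `__eq__` or `__hash__` of its own, so it falls back to object identity. The standard-library sketch below illustrates the difference; it assumes the wrapper returned by `Module.__get__` behaves like a plain `functools.partial` in this respect (JAX's `jax.tree_util.Partial` subclasses `functools.partial`).

```python
import functools


class Plain:
    """An ordinary class (not an eqx.Module), used only for comparison."""

    def f(self):
        return None


obj = Plain()

# Built-in bound methods: every access creates a new method object...
m1, m2 = obj.f, obj.f
print(m1 is m2)              # False
# ...but equality and hashing are defined by (function, instance),
# so a cache keyed on __hash__/__eq__ would treat them as the same key.
print(m1 == m2)              # True
print(hash(m1) == hash(m2))  # True

# functools.partial falls back to identity-based __eq__/__hash__,
# so two partials binding the same function to the same instance never match.
p1 = functools.partial(Plain.f, obj)
p2 = functools.partial(Plain.f, obj)
print(p1 == p2)              # False
```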

tl;dr: a bound method gets a different hash each time it is accessed, because of this wrapper: https://github.com/patrick-kidger/equinox/blob/2cf6fecb1fb275a13d4931ec79fa91e4f38cc315/equinox/module.py#L64. We should probably memoize the result of `__get__`, so that accessing a bound method does not create a new `Partial` every time (this assumes the `Module` instance is immutable).
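A minimal sketch of the memoization idea, written against a frozen `dataclasses.dataclass` rather than `eqx.Module` so that it runs without Equinox. `memoized_method` is a hypothetical descriptor, not part of Equinox's API: on first access it builds the bound wrapper and stores it in the instance's `__dict__` (the same trick `functools.cached_property` uses, which sidesteps the frozen `__setattr__`), so later accesses return the very same object and therefore the same hash. Whether this carries over to `eqx.Module` depends on how `Module` stores attributes and flattens itself, which is an assumption not checked here.

```python
import dataclasses
import functools


class memoized_method:
    """Hypothetical descriptor: hands out one cached bound wrapper per instance."""

    def __init__(self, method):
        self.method = method
        self.name = method.__name__

    def __set_name__(self, owner, name):
        # Key the cache by the attribute name the descriptor is assigned to.
        self.name = name

    def __get__(self, instance, owner=None):
        if instance is None:
            # Accessed on the class: return the plain function.
            return self.method
        # Build the wrapper and store it directly in the instance __dict__.
        # Writing to __dict__ bypasses the frozen dataclass's __setattr__ guard,
        # and because this is a non-data descriptor, later lookups of the same
        # attribute find the cached wrapper in __dict__ without calling __get__.
        bound = functools.partial(self.method, instance)
        instance.__dict__[self.name] = bound
        return bound


@dataclasses.dataclass(frozen=True)
class Model:
    # No annotation, so dataclasses does not treat `f` as a field.
    @memoized_method
    def f(self):
        return None


model = Model()
f1 = model.f
f2 = model.f
print(f1 is f2)              # True: the same cached object both times
print(hash(f1) == hash(f2))  # True
print(f1())                  # None
```

The cache is per instance, so a structurally equal but distinct instance (for example, one rebuilt by unflattening a pytree) still gets a fresh wrapper with a different hash. Storing the wrapper also creates a reference cycle (instance to wrapper to instance), which Python's cyclic garbage collector reclaims.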

KeAWang · Mar 16 '23 21:03