equinox
Runtime penalty with fields that have expensive inits
I was surprised by the runtime of a posterior estimation I was coding. Here is an MWE to reproduce the issue:
import jax.numpy as jnp
import equinox as eqx
import distrax

class Parameter(eqx.Module):
    value: jnp.ndarray
    prior: distrax.Distribution

    def log_prior(self):
        return self.prior.log_prob(self.value)

v = jnp.zeros(10)
mu = jnp.ones(10)
sigma = jnp.ones(10)
p0 = distrax.MultivariateNormalDiag(mu, sigma)
p = Parameter(v, p0)

# test with some data
import optax
from jax import random

key = random.PRNGKey(0)
data = random.normal(key, (10,))

@eqx.filter_value_and_grad
def loss_fn_with_grad(model, data):
    neg_log_like = 0.5 * ((model.value - data)**2).sum()
    return neg_log_like - model.log_prior()

@eqx.filter_jit
def make_step(model, data, opt_state):
    loss, grads = loss_fn_with_grad(model, data)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

learning_rate = 1e-1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(p, eqx.is_array))

for step in range(100):
    loss, p, opt_state = make_step(p, data, opt_state)
    loss = loss.item()
    print(f"step={step}, loss={loss}")
The runtime for a single call to make_step is about 75 ms on a single core. This drops to 100 µs when the prior is declared static:
class Parameter(eqx.Module):
    value: jnp.ndarray
    prior: distrax.Distribution = eqx.static_field()

    def log_prior(self):
        return self.prior.log_prob(self.value)
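
Per-call timings like the ones quoted above can be reproduced with a small helper along the following lines. This is only a sketch: `time_per_call` and its `sync` parameter are hypothetical names, not part of equinox or JAX. The warm-up calls keep one-off costs such as JIT compilation out of the measurement, and `sync` is there because JAX dispatches work asynchronously, so the result has to be forced before the clock stops.

```python
import time


def time_per_call(fn, *args, warmup=1, repeats=100, sync=lambda out: out):
    """Return the mean wall-clock seconds per call of fn(*args).

    The first `warmup` calls are not timed, so one-off costs such as JIT
    compilation are excluded. `sync` is applied to every result before the
    clock stops; with JAX it should block on the output.
    """
    for _ in range(warmup):
        sync(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        sync(fn(*args))
    return (time.perf_counter() - start) / repeats
```

With the example above, a call such as `time_per_call(make_step, p, data, opt_state, sync=lambda out: out[0].item())` would time a single step, using `loss.item()` to wait for the result just as the training loop does.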
The issue, I think, is that in the first case, filtering creates new instances of prior:
params, static = eqx.partition(p, eqx.is_array)
print(p)
print(params)
print(static)
yields
Parameter(
  value=f32[10],
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x13e9ae9d0>
)
Parameter(
  value=f32[10],
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x15b82c400>
)
Parameter(
  value=None,
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x15bd2d370>
)
In the case of prior: distrax.Distribution = eqx.static_field(), it's the same instance. I suspect that the cost of initializing the new instances is what's dragging down the performance. This is probably the intended behavior here, but I was surprised by it because in neither case is prior treated as a tree leaf, which I thought is what gets altered on each gradient update.
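
One possible explanation for the new instances, offered as an assumption rather than a confirmed diagnosis: if the distribution object is itself registered as a pytree, then it is a subtree rather than a leaf, and any flatten/unflatten round trip rebuilds it. The sketch below uses plain `jax.tree_util` with a hypothetical stand-in class, `ToyPrior`, instead of a real distrax distribution, so it only illustrates the mechanism.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyPrior:
    """Hypothetical stand-in for a distribution that is itself a pytree."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

    def tree_flatten(self):
        # The two arrays are children (leaves); there is no static aux data.
        return (self.loc, self.scale), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuilding goes through the constructor, so a new object is created.
        return cls(*children)


prior = ToyPrior(jnp.ones(10), jnp.ones(10))

# The object is a subtree with two array leaves, not a single opaque leaf.
leaves, treedef = tree_util.tree_flatten(prior)
print(len(leaves))  # 2

# A flatten/unflatten round trip wraps the same leaves in a fresh object.
rebuilt = tree_util.tree_unflatten(treedef, leaves)
print(rebuilt is prior)  # False
```

If distrax distributions behave like `ToyPrior` (not verified here), running `jax.tree_util.tree_leaves(p)` on the non-static `Parameter` would list the arrays inside the prior alongside `value`. Those arrays would then also be differentiated and updated in `make_step`, whereas a static field is skipped when the tree is flattened.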
I might be doing this wrong, but if not, it would be good to clarify the use of static_field in the docs, where it's generally discouraged.