Runtime penalty with fields that have expensive inits

Open pmelchior opened this issue 1 year ago • 5 comments

I was surprised by the runtime of a posterior estimation I was coding. Here is an MWE to reproduce the issue:


import jax.numpy as jnp
import equinox as eqx
import distrax

class Parameter(eqx.Module):
    value: jnp.ndarray
    prior: distrax.Distribution

    def log_prior(self):
        return self.prior.log_prob(self.value)

v = jnp.zeros(10)
mu = jnp.ones(10)
sigma = jnp.ones(10)
p0 = distrax.MultivariateNormalDiag(mu, sigma)
p = Parameter(v, p0)

# test with some data
import optax
from jax import random
key = random.PRNGKey(0)
data = random.normal(key, (10,))

@eqx.filter_value_and_grad
def loss_fn_with_grad(model, data):
    neg_log_like = 0.5 * ((model.value - data)**2).sum()
    return neg_log_like - model.log_prior()

@eqx.filter_jit
def make_step(model, data, opt_state):
    loss, grads = loss_fn_with_grad(model, data)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

learning_rate=1e-1
optim = optax.adam(learning_rate)
opt_state = optim.init(eqx.filter(p, eqx.is_array))

for step in range(100):
    loss, p, opt_state = make_step(p, data, opt_state)
    loss = loss.item()
    print(f"step={step}, loss={loss}")

The runtime for a single call to make_step is about 75 ms on a single core. This drops to 100 µs when the prior is declared static:

class Parameter(eqx.Module):
    value: jnp.ndarray
    prior: distrax.Distribution = eqx.static_field()

    def log_prior(self):
        return self.prior.log_prob(self.value)
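
For reference, per-call timings like the ones quoted above could be measured roughly as follows. This is only a sketch: `time_per_call` and `toy_step` are hypothetical names, not part of the original script, and the toy function merely stands in for `make_step`. The warm-up call keeps one-off tracing and compilation out of the measurement, and `jax.block_until_ready` is needed because JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


def time_per_call(fn, *args, repeats=100):
    """Return the mean wall-clock seconds per call of fn(*args).

    Hypothetical helper for illustration only.
    """
    # Warm-up call so that tracing/compilation is not counted.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for each result; otherwise only the dispatch would be timed.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


@jax.jit
def toy_step(x):
    # Stand-in for make_step: one gradient-descent-like update.
    return x - 0.1 * (x - 1.0)


seconds = time_per_call(toy_step, jnp.zeros(10))
print(f"{seconds * 1e6:.1f} µs per call")
```

With the script above, the equivalent call would be `time_per_call(make_step, p, data, opt_state)`. The number includes any Python-side work done on every call, not just the compiled computation.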

The issue, I think, is that in the first case, filtering creates new instances of prior:

params, static = eqx.partition(p, eqx.is_array)
print(p)
print(params)
print(static)

yields

Parameter(
  value=f32[10],
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x13e9ae9d0>
)
Parameter(
  value=f32[10],
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x15b82c400>
)
Parameter(
  value=None,
  prior=<distrax._src.distributions.mvn_diag.MultivariateNormalDiag object at 0x15bd2d370>
)

In the case of prior: distrax.Distribution = eqx.static_field(), it's the same instance. I suspect that the cost of initializing the new instances is what's dragging down the performance. This is probably the intended behavior, but I was surprised by it because in neither case is prior treated as a tree leaf, and I thought the leaves are what gets altered on each gradient update.
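
The "new instances" behavior can be reproduced without distrax or equinox. The sketch below assumes that the prior object is registered as a pytree node and that filtering is built on tree mapping. `ToyPrior` is a hypothetical stand-in, not distrax's actual class. The sketch only shows that mapping over a tree rebuilds such nodes; it does not establish where the time is spent in the original example.

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class ToyPrior:
    """Hypothetical stand-in for a distribution registered as a pytree node."""

    unflatten_calls = 0  # counts how often JAX rebuilds an instance

    def __init__(self, mu, sigma):
        self.mu = mu
        self.sigma = sigma

    def tree_flatten(self):
        # Children are the array leaves; there is no static aux data here.
        return (self.mu, self.sigma), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        cls.unflatten_calls += 1
        # This toy class rebuilds itself through __init__; a real library
        # may reconstruct its objects differently.
        return cls(*children)


prior = ToyPrior(jnp.ones(10), jnp.ones(10))
tree = {"value": jnp.zeros(10), "prior": prior}

# Even an identity map flattens the tree and unflattens it again,
# so the pytree node comes back as a new Python object.
rebuilt = tree_util.tree_map(lambda x: x, tree)

print(rebuilt["prior"] is prior)
print(ToyPrior.unflatten_calls)
```

An object that JAX does not traverse (for example, one kept as static metadata) would be passed through as is, which would match the observation that the static prior keeps its identity.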

I might be doing this wrong, but if not it would be good to clarify the use of static_field in the docs, where it's generally discouraged.

pmelchior, Jan 08 '23 01:01