Enabling tracer leak check causes error
I have this bilevel optimization setup with equinox and optax.
When I enable tracer leak checking with jax.config.update("jax_check_tracer_leaks", True),
it raises an error from within Equinox!
import equinox as eqx
import equinox.internal as eqxi
import optimistix as optx
import jax.numpy as jnp
import optax
import jax
from typing import Callable

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_check_tracer_leaks", True)


def xu_projection(design: jax.Array, beta: float, eta: float):
    design = (jnp.tanh(beta * eta) + jnp.tanh(beta * (design - eta))) / (
        jnp.tanh(beta * eta) + jnp.tanh(beta * (1 - eta))
    )
    return design
def xu_volume_preserving_filter(
    design: jax.Array,
    beta: float,
    target_vf: float = 0.5,
) -> jax.Array:
    def transform_design(design, eta):
        # Sigmoidal transform
        transformed = xu_projection(design, beta, eta)
        return transformed

    def volume_violation_fn(y, args):
        loss = transform_design(args, y).mean() - target_vf
        return loss

    bisection = optx.Bisection(atol=1e-9, rtol=1e-9)
    sol = optx.root_find(
        fn=volume_violation_fn,
        solver=bisection,
        y0=0.5,
        args=design,
        options=dict(lower=0, upper=1),
        max_steps=750,
        throw=False,
    )
    return transform_design(design, sol.value)
class PixelModel(eqx.Module):
    params: jnp.ndarray
    param_static_filter: callable

    def __init__(self, Nx, Ny, target_vf=0.5, param_static_filter=eqx.is_array):
        self.param_static_filter = param_static_filter
        # create trainable parameters
        self.params = jnp.ones((Ny, Nx)) * target_vf

    def __call__(self, x=None):
        return self.params.ravel().reshape(-1, 1)


class InnerState(eqx.Module):
    model: eqx.Module
    optimizer_state: optax.OptState
    hyperparams: tuple

    def __init__(self, model, optimizer_state, hyperparams):
        self.model = model
        self.optimizer_state = optimizer_state
        self.hyperparams = hyperparams
def inner_optimization_step(state, loss_fn, optimizer):
    params, static = eqx.partition(state.model, eqx.is_array)
    (loss, aux), grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
        params, static, state.hyperparams)

    def loss_fn_scalar(p):
        return loss_fn(p, static, state.hyperparams)[0]

    updates, new_opt_state = optimizer.update(
        grads, state.optimizer_state, params, value=loss,
        grad=grads, value_fn=loss_fn_scalar)
    new_params = optax.apply_updates(params, updates)
    new_model = eqx.combine(new_params, static)
    new_opt_state = jax.tree.map(
        lambda x, y: x.astype(y.dtype), new_opt_state, state.optimizer_state)
    return InnerState(new_model, new_opt_state, state.hyperparams), aux
class BilevelObjective(eqx.Module):
    inner_model: eqx.Module
    inner_loss_fn: Callable
    inner_optimizer: optax.GradientTransformation = eqx.field(
        converter=eqxi.closure_to_pytree)
    outer_loss_fn: Callable

    def init_inner_state(self, outer_params):
        inner_params, inner_static = eqx.partition(
            self.inner_model, self.inner_model.param_static_filter)
        inner_opt_state = self.inner_optimizer.init(inner_params)
        return InnerState(eqx.combine(inner_params, inner_static),
                          inner_opt_state,
                          outer_params)

    def calculate_outer_loss(self, outer_params, num_inner_steps):
        """Computes outer loss after inner optimization."""
        state = self.init_inner_state(outer_params)
        final_state, inner_aux = self.run_inner_optimization(
            state, num_inner_steps, self.inner_loss_fn, self.inner_optimizer)
        optimized_inner_params, inner_static = eqx.partition(
            final_state.model, final_state.model.param_static_filter)
        loss, outer_aux = self.outer_loss_fn(
            optimized_inner_params, inner_static, final_state.hyperparams)
        return loss, (outer_aux, inner_aux, final_state)

    def run_inner_optimization(self, state, num_inner_steps, loss_fn, optimizer):
        """Runs inner optimization."""
        for _ in range(num_inner_steps):
            state, aux = inner_optimization_step(state, loss_fn, optimizer)
        return state, aux
def main():
    domain_shape = (10, 10)
    inner_model = PixelModel(
        Nx=domain_shape[0], Ny=domain_shape[1], target_vf=0.5,
        param_static_filter=eqx.is_array)
    inner_optimizer = optax.adam(learning_rate=1e-3, eps_root=1e-8)

    def design_parameterisation(params, static, beta):
        model = eqx.combine(params, static)
        rho = model(None)
        # rho = density_filter(rho)
        rho = xu_volume_preserving_filter(rho, beta=beta, target_vf=0.5)
        return rho.reshape(domain_shape)

    def inner_loss_fn(params, static, hparams):
        beta = hparams[1]  # hparams.get_scalar("proj_beta")
        penalty = hparams[0]  # hparams.get_scalar("simp_penalty")
        rho = design_parameterisation(params, static, beta=beta)
        compliance = jnp.mean(rho**2)  # solver(rho, sample, penalty)
        loss = compliance
        return loss, {"design": rho,
                      "penalty": penalty,
                      "beta": beta,
                      "compliance": compliance}

    def outer_loss_fn(params, static, hparams):
        beta = hparams[1]
        penalty = 1.0
        rho = design_parameterisation(params, static, beta=beta)  # 1.0)
        comp_p_1 = jnp.mean(rho**3)  # solver(rho, sample, penalty)
        loss = comp_p_1
        return loss, {"design_star": rho,
                      "comp_p_1": comp_p_1,
                      "beta": beta,
                      "penalty": penalty}

    outer_objective = BilevelObjective(
        inner_model=inner_model,
        inner_loss_fn=inner_loss_fn,
        inner_optimizer=inner_optimizer,
        outer_loss_fn=outer_loss_fn)
    outer_params = (jnp.array(0.5), jnp.array(1.0))
    num_inner_steps = 5
    loss, aux = outer_objective.calculate_outer_loss(outer_params, num_inner_steps)
    print("Outer loss:", loss)


if __name__ == "__main__":
    main()
Optax: 0.2.4, Equinox: 0.11.12, JAX: 0.5.1, Optimistix: 0.0.9
Hi, thanks for the report! Getting an IndexError suggests that it may actually be an error within JAX's own error-checking machinery.
Either way, can you try minimising this down to the smallest-possible reproduction? This will really help us to identify what's going wrong. From your attached stack trace I can see that it's likely coming from within optimistix.Bisection.init.
Please also be sure to update to the latest version of each of these libraries (Equinox 0.12.1, JAX 0.5.3, Optimistix 0.0.10).
Hi Patrick,
I upgraded everything and the issue is still there. I managed to reduce the example further; will this be okay?
import equinox as eqx
import optimistix as optx
import jax.numpy as jnp
import optax
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_check_tracer_leaks", True)


def xu_projection(design: jax.Array, beta: float, eta: float):
    design = (jnp.tanh(beta * eta) + jnp.tanh(beta * (design - eta))) / (
        jnp.tanh(beta * eta) + jnp.tanh(beta * (1 - eta)))
    return design


def xu_volume_preserving_filter(
    design: jax.Array,
    beta: float,
    target_vf: float = 0.5,
) -> jax.Array:
    def transform_design(design, eta):
        # Sigmoidal transform
        transformed = xu_projection(design, beta, eta)
        return transformed

    def volume_violation_fn(y, args):
        loss = transform_design(args, y).mean() - target_vf
        return loss

    bisection = optx.Bisection(atol=1e-3, rtol=1e-3)
    sol = optx.root_find(
        fn=volume_violation_fn,
        solver=bisection,
        y0=0.5,
        args=design,
        options=dict(lower=0, upper=1),
        max_steps=750,
        throw=False,
    )
    return transform_design(design, sol.value)
def inner_optimization_step(state, loss_fn, optimizer):
    _, grads = eqx.filter_value_and_grad(loss_fn)(state[0], state[2])
    updates, new_opt_state = optimizer.update(grads, state[1])
    new_params = optax.apply_updates(state.model, updates)
    return (new_params, new_opt_state, state[2])


def run_inner_optimization(outer_params,
                           num_inner_steps,
                           inner_model,
                           inner_loss_fn,
                           inner_optimizer):
    inner_params = inner_model
    inner_opt_state = inner_optimizer.init(inner_params)
    state = (inner_params, inner_opt_state, outer_params)
    for _ in range(num_inner_steps):
        state = inner_optimization_step(state, inner_loss_fn, inner_optimizer)
    return state
def main():
    domain_shape = (10, 10)
    inner_model = jnp.ones((domain_shape[0], domain_shape[1])) * 0.5
    inner_optimizer = optax.adam(learning_rate=1e-3, eps_root=1e-8)

    def inner_loss_fn(params, hparams):
        rho = xu_volume_preserving_filter(params, beta=hparams, target_vf=0.5)
        loss = jnp.mean(rho**2)
        return loss

    outer_params = jnp.array(50.0)
    num_inner_steps = 10
    final_state = run_inner_optimization(
        outer_params, num_inner_steps, inner_model, inner_loss_fn, inner_optimizer)
    print("Final state:", final_state)


if __name__ == "__main__":
    main()
Hi Suryanarayanan,
I tried running this - which required some edits, presumably because state is no longer some container with a model attribute:
def inner_optimization_step(state, loss_fn, optimizer):
    inner_params, inner_opt_state, outer_params = state
    _, grads = eqx.filter_value_and_grad(loss_fn)(inner_params, outer_params)
    updates, new_opt_state = optimizer.update(grads, inner_opt_state)
    new_params = optax.apply_updates(inner_params, updates)
    return (new_params, new_opt_state, outer_params)


def run_inner_optimization(
    outer_params, num_inner_steps, inner_model, inner_loss_fn, inner_optimizer
):
    inner_params = inner_model
    inner_opt_state = inner_optimizer.init(inner_params)
    state = (inner_params, inner_opt_state, outer_params)
    for _ in range(num_inner_steps):
        state = inner_optimization_step(state, inner_loss_fn, inner_optimizer)
    return state
(inner_model is now just a pytree of parameters, renamed inner_params, and accordingly inner_params is what must be passed to optax.apply_updates, rather than state.model, which no longer exists.)
I don't get any errors on the newest versions of the specified libraries.
Hope this helps! Perhaps you can backtrack from here to the version that did give you errors, and figure out what went wrong along the way :)
Hey,
I tried it in a new environment (by the way, I am running it on a Mac, but with the CPU version of JAX).
I still get the error. I changed the solver to Newton but the error persists.
It seems to come from within JAX:
File ~/miniforge3/envs/metatopia_new/lib/python3.11/site-packages/lineax/_solve.py:114, in _linear_solve_impl(_, state, vector, options, solver, throw, check_closure)
108 result = RESULTS.where(
109 (result == RESULTS.singular) & has_nonfinite_input,
110 RESULTS.nonfinite_input,
111 result,
112 )
113 if throw:
--> 114 solution, result, stats = result.error_if(
115 (solution, result, stats),
116 result != RESULTS.successful,
117 )
118 return solution, result, stats
[... skipping hidden 15 frame]
File ~/miniforge3/envs/metatopia_new/lib/python3.11/contextlib.py:144, in _GeneratorContextManager.__exit__(self, typ, value, traceback)
142 if typ is None:
143 try:
--> 144 next(self.gen)
145 except StopIteration:
146 return False
[... skipping hidden 3 frame]
File ~/miniforge3/envs/metatopia_new/lib/python3.11/site-packages/jax/_src/core.py:1305, in _why_alive(ignore_ids, x)
1292 parent = parents(child)[0] # just pick one parent
1294 # For namespaces (like modules and class instances) and closures, the
1295 # references may form a simple chain: e.g. instance refers to its own
1296 # __dict__ which refers to child, or function refers to its __closure__
(...) 1302 # https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599
1303 # To prevent this collapsing behavior, just comment out this code block.
1304 if (isinstance(parent, dict) and
-> 1305 getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
1306 parent = parents(parent)[0]
1307 elif type(parent) is types.CellType:
IndexError: list index out of range
Unfortunately I cannot reproduce the error with the last example you posted. With that, I only get an AttributeError that is easily fixed.
I tried the code as well and managed to reproduce the error. I think it might be related to the following issue in JAX: https://github.com/jax-ml/jax/issues/26560
Which one did you try, the first or the second example? If we can figure out what the difference between your setups and mine is, that would be a step forward! I used the somewhat reduced example @SNMS95 posted yesterday, not the longer one from three days ago.
jax 0.5.3, equinox 0.12.1, optimistix 0.0.10 (on current main). MacBook with M1 processor.
I tried it with the following settings: a Conda environment with pip-installed python=3.11, jax=0.5.3, equinox=0.12.1, and optimistix=0.0.10, on a Mac with an M3 processor.
I can reproduce this with Python 3.11.9, but not with Python 3.13.1. I could reduce the MWE to this:
import jax
import optimistix as optx

jax.config.update("jax_check_tracer_leaks", True)


def match(y, args):  # Has trivial root at 0.5
    del args
    return y - 0.5


solver = optx.Bisection(rtol=1e-3, atol=1e-3)
y0 = 0.75
optx.root_find(match, solver, y0, options=dict(lower=0, upper=1))
This most probably needs a fix over in Optimistix; I can look into it tomorrow. @patrick-kidger, do you have suggestions as to what changed between these Python versions? The root finder can be swapped for another one and the error remains (see the sketch below); enabling the tracer leak checks is what is required to trigger it.
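For illustration only, a minimal sketch of the same MWE with the solver swapped to optx.Newton (which needs no lower/upper bracket); the point is that the specific root finder does not matter, only the tracer leak check does:

import jax
import optimistix as optx

jax.config.update("jax_check_tracer_leaks", True)


def match(y, args):  # Has trivial root at 0.5
    del args
    return y - 0.5


# Any root finder reproduces the same behaviour; Newton is just one choice.
solver = optx.Newton(rtol=1e-3, atol=1e-3)
optx.root_find(match, solver, 0.75)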
Alright, this is not an Optimistix error after all. The IndexError is caused by an interaction between Equinox's runtime errors and Python 3.11; it works on Python 3.13 (I haven't checked 3.12).
import equinox as eqx
import jax
import jax.numpy as jnp
jax.config.update("jax_check_tracer_leaks", True)
x1 = jnp.ones(3)
x2 = eqx.error_if(x1, jnp.array(True), "This should not be an IndexError") # pred can be False or True
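For contrast, a minimal sketch of the expected behaviour when the leak check is left at its default (off): eqx.error_if simply returns its input for a False predicate, and for a True predicate it raises Equinox's usual runtime error carrying the message, not an IndexError.

import equinox as eqx
import jax.numpy as jnp

# No jax_check_tracer_leaks here: with a False predicate error_if is a no-op
# and returns x1 unchanged; with a True predicate it raises at runtime with
# the given message.
x1 = jnp.ones(3)
x2 = eqx.error_if(x1, jnp.array(False), "never raised for a False predicate")
print(x2)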
Interesting! This Equinox-only MWE here is super useful.
As another preliminary result on my side, I can avoid this error by dropping back to JAX 0.4.35, so it seems that this is something that came in with JAX's 'stackless' change. I can also avoid this error by removing our JIT wrapper around branched_error_if_impl. Probably the next step is to further simplify this MWE to avoid Equinox too.
Thanks for the pointers! I could get down to a JAX-only reproduction by removing everything but the traceback walk from branched_error_if_impl.