
Enabling tracer leak check causes error

Open · SNMS95 opened this issue 8 months ago • 12 comments

I have this bilevel optimization setup with equinox and optax. When I set the tracer leak option with jax.config.update("jax_check_tracer_leaks", True), it generates an error within equinox!

import equinox as eqx
import equinox.internal as eqxi
import optimistix as optx
import jax.numpy as jnp
import optax
import jax
from typing import Callable

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_check_tracer_leaks", True)


def xu_projection(design: jax.Array, beta: float, eta: float):
    design = (jnp.tanh(beta * eta) + jnp.tanh(beta * (design - eta))) / (
        jnp.tanh(beta * eta) + jnp.tanh(beta * (1 - eta))
    )
    return design


def xu_volume_preserving_filter(
    design: jax.Array,
    beta: float,
    target_vf: float = 0.5,
) -> jax.Array:

    def transform_design(design, eta):
        # Sigmoidal transform
        transformed = xu_projection(design, beta, eta)
        return transformed

    def volume_violation_fn(y, args):
        loss = transform_design(args, y).mean() - target_vf
        return loss

    bisection = optx.Bisection(atol=1e-9, rtol=1e-9)

    sol = optx.root_find(
        fn=volume_violation_fn,
        solver=bisection,
        y0=0.5,
        args=design,
        options=dict(lower=0, upper=1),
        max_steps=750,
        throw=False,
    )
    return transform_design(design, sol.value)


class PixelModel(eqx.Module):
    params: jnp.ndarray
    param_static_filter: Callable

    def __init__(self, Nx, Ny, target_vf=0.5,
                 param_static_filter=eqx.is_array):
        self.param_static_filter = param_static_filter
        # create trainable parameters
        self.params = jnp.ones((Ny, Nx)) * target_vf

    def __call__(self, x=None):
        return self.params.ravel().reshape(-1, 1)


class InnerState(eqx.Module):
    model: eqx.Module
    optimizer_state: optax.OptState
    hyperparams: tuple

    def __init__(self, model, optimizer_state, hyperparams):
        self.model = model
        self.optimizer_state = optimizer_state
        self.hyperparams = hyperparams


def inner_optimization_step(state, loss_fn, optimizer):
    params, static = eqx.partition(state.model, eqx.is_array)
    (loss, aux), grads = eqx.filter_value_and_grad(loss_fn, has_aux=True)(
        params, static, state.hyperparams)

    def loss_fn_scalar(p):
        return loss_fn(p, static, state.hyperparams)[0]

    updates, new_opt_state = optimizer.update(
        grads, state.optimizer_state, params, value=loss,
        grad=grads,
        value_fn=loss_fn_scalar)
    new_params = optax.apply_updates(params, updates)
    new_model = eqx.combine(new_params, static)
    new_opt_state = jax.tree.map(lambda x, y: x.astype(
        y.dtype), new_opt_state, state.optimizer_state)
    return InnerState(new_model, new_opt_state, state.hyperparams), aux


class BilevelObjective(eqx.Module):
    inner_model: eqx.Module
    inner_loss_fn: Callable
    inner_optimizer: optax.GradientTransformation = eqx.field(
        converter=eqxi.closure_to_pytree)
    outer_loss_fn: Callable

    def init_inner_state(self, outer_params):
        inner_params, inner_static = eqx.partition(
            self.inner_model, self.inner_model.param_static_filter)
        inner_opt_state = self.inner_optimizer.init(inner_params)
        return InnerState(eqx.combine(inner_params, inner_static),
                          inner_opt_state,
                          outer_params)

    def calculate_outer_loss(self, outer_params,
                             num_inner_steps,
                             ):
        """Computes outer loss after inner optimization."""
        state = self.init_inner_state(outer_params)
        final_state, inner_aux = self.run_inner_optimization(
            state, num_inner_steps, self.inner_loss_fn,
            self.inner_optimizer)
        optimized_inner_params, inner_static = eqx.partition(
            final_state.model, final_state.model.param_static_filter)
        loss, outer_aux = self.outer_loss_fn(
            optimized_inner_params, inner_static,
            final_state.hyperparams)
        return loss, (outer_aux, inner_aux, final_state)

    def run_inner_optimization(self, state, num_inner_steps, loss_fn,
                               optimizer):
        """Runs inner optimization."""
        for _ in range(num_inner_steps):
            state, aux = inner_optimization_step(state, loss_fn, optimizer)
        return state, aux


def main():
    domain_shape = (10, 10)
    inner_model = PixelModel(
        Nx=domain_shape[0], Ny=domain_shape[1], target_vf=0.5,
        param_static_filter=eqx.is_array)

    inner_optimizer = optax.adam(learning_rate=1e-3, eps_root=1e-8)

    def design_parameterisation(params, static, beta):
        model = eqx.combine(params, static)
        rho = model(None)
        # rho = density_filter(rho)
        rho = xu_volume_preserving_filter(rho, beta=beta, target_vf=0.5)
        return rho.reshape(domain_shape)

    def inner_loss_fn(params, static, hparams):
        beta = hparams[1]  # hparams.get_scalar("proj_beta")
        penalty = hparams[0]  # hparams.get_scalar("simp_penalty")
        rho = design_parameterisation(params, static,
                                      beta=beta)
        compliance = jnp.mean(rho**2)  # solver(rho, sample, penalty)
        loss = compliance

        return loss, {"design": rho,
                      "penalty": penalty,
                      "beta": beta,
                      "compliance": compliance, }

    def outer_loss_fn(params, static, hparams):
        beta = hparams[1]
        penalty = 1.0
        rho = design_parameterisation(params, static,
                                      beta=beta)  # 1.0)
        comp_p_1 = jnp.mean(rho**3)  # solver(rho, sample, penalty)
        loss = comp_p_1
        return loss, {"design_star": rho,
                      "comp_p_1": comp_p_1,
                      "beta": beta,
                      "penalty": penalty}
    outer_objective = BilevelObjective(
        inner_model=inner_model,
        inner_loss_fn=inner_loss_fn,
        inner_optimizer=inner_optimizer,
        outer_loss_fn=outer_loss_fn)
    outer_params = (jnp.array(0.5), jnp.array(1.0))
    num_inner_steps = 5
    loss, aux = outer_objective.calculate_outer_loss(
        outer_params, num_inner_steps)
    print("Outer loss:", loss)


if __name__ == "__main__":
    main()

(Screenshot of the stack trace attached as an image.)

Versions: Optax 0.2.4, Equinox 0.11.12, JAX 0.5.1, Optimistix 0.0.9

SNMS95 · Apr 11 '25

Hi, thanks for the report! Getting an IndexError suggests that it may actually be an error within JAX's own error-checking machinery.

Either way, can you try minimising this down to the smallest possible reproduction? That will really help us identify what's going wrong. From your attached stack trace I can see that it's likely coming from within optimistix.Bisection.init.

Please also be sure to update to the latest version of each of these libraries (Equinox 0.12.1, JAX 0.5.3, Optimistix 0.0.10).
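
For instance, assuming a pip-based environment, something like the following should pull in those releases:

pip install --upgrade equinox jax optimistix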

patrick-kidger · Apr 11 '25

Hi Patrick,

I upgraded everything and the issue is still there. I managed to reduce the example further; will this be okay?

import equinox as eqx
import optimistix as optx
import jax.numpy as jnp
import optax
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_check_tracer_leaks", True)


def xu_projection(design: jax.Array, beta: float, eta: float):
    design = (jnp.tanh(beta * eta) + jnp.tanh(beta * (design - eta))) / (
        jnp.tanh(beta * eta) + jnp.tanh(beta * (1 - eta)))
    return design


def xu_volume_preserving_filter(
    design: jax.Array,
    beta: float,
    target_vf: float = 0.5,
) -> jax.Array:

    def transform_design(design, eta):
        # Sigmoidal transform
        transformed = xu_projection(design, beta, eta)
        return transformed

    def volume_violation_fn(y, args):
        loss = transform_design(args, y).mean() - target_vf
        return loss

    bisection = optx.Bisection(atol=1e-3, rtol=1e-3)

    sol = optx.root_find(
        fn=volume_violation_fn,
        solver=bisection,
        y0=0.5,
        args=design,
        options=dict(lower=0, upper=1),
        max_steps=750,
        throw=False,)
    return transform_design(design, sol.value)


def inner_optimization_step(state, loss_fn, optimizer):
    _, grads = eqx.filter_value_and_grad(loss_fn)(
        state[0], state[2])
    updates, new_opt_state = optimizer.update(
        grads, state[1])
    new_params = optax.apply_updates(state.model, updates)
    return (new_params, new_opt_state, state[2])


def run_inner_optimization(outer_params,
                           num_inner_steps,
                           inner_model,
                           inner_loss_fn,
                           inner_optimizer):
    inner_params = inner_model
    inner_opt_state = inner_optimizer.init(inner_params)
    state = (inner_params,
             inner_opt_state,
             outer_params)
    for _ in range(num_inner_steps):
        state = inner_optimization_step(state, inner_loss_fn, inner_optimizer)
    return state


def main():
    domain_shape = (10, 10)
    inner_model = jnp.ones((domain_shape[0], domain_shape[1])) * 0.5
    inner_optimizer = optax.adam(learning_rate=1e-3, eps_root=1e-8)

    def inner_loss_fn(params, hparams):
        rho = xu_volume_preserving_filter(params, beta=hparams, target_vf=0.5)
        loss = jnp.mean(rho**2)
        return loss

    outer_params = jnp.array(50.0)
    num_inner_steps = 10
    final_state = run_inner_optimization(
        outer_params, num_inner_steps, inner_model, inner_loss_fn, inner_optimizer)
    print("Final state:", final_state)


if __name__ == "__main__":
    main()

SNMS95 · Apr 13 '25

Hi Suryanarayanan,

I tried running this, which required some edits, presumably because state is no longer some container with a model attribute:

def inner_optimization_step(state, loss_fn, optimizer):
    inner_params, inner_opt_state, outer_params = state
    _, grads = eqx.filter_value_and_grad(loss_fn)(inner_params, outer_params)
    updates, new_opt_state = optimizer.update(grads, inner_opt_state)
    new_params = optax.apply_updates(inner_params, updates)
    return (new_params, new_opt_state, outer_params)


def run_inner_optimization(
    outer_params, num_inner_steps, inner_model, inner_loss_fn, inner_optimizer
):
    inner_params = inner_model
    inner_opt_state = inner_optimizer.init(inner_params)
    state = (inner_params, inner_opt_state, outer_params)
    for _ in range(num_inner_steps):
        state = inner_optimization_step(state, inner_loss_fn, inner_optimizer)
    return state

(inner_model is now just a pytree of parameters, renamed inner_params, and accordingly inner_params is what must be passed to optax.apply_updates, rather than state.model, which no longer exists.) I don't get any errors on the newest versions of the specified libraries.

Hope this helps! Perhaps you can backtrack from here to the version that did give you errors, and figure out what went wrong along the way :)

johannahaffner · Apr 13 '25

Hey,

I tried it in a new environment (btw, I am running it on a Mac, but with the CPU version of JAX). I still get the error. I changed the solver to Newton, but the error persists.

It seems to come from within JAX:

File ~/miniforge3/envs/metatopia_new/lib/python3.11/site-packages/lineax/_solve.py:114, in _linear_solve_impl(_, state, vector, options, solver, throw, check_closure)
    108 result = RESULTS.where(
    109     (result == RESULTS.singular) & has_nonfinite_input,
    110     RESULTS.nonfinite_input,
    111     result,
    112 )
    113 if throw:
--> 114     solution, result, stats = result.error_if(
    115         (solution, result, stats),
    116         result != RESULTS.successful,
    117     )
    118 return solution, result, stats

    [... skipping hidden 15 frame]

File ~/miniforge3/envs/metatopia_new/lib/python3.11/contextlib.py:144, in _GeneratorContextManager.__exit__(self, typ, value, traceback)
    142 if typ is None:
    143     try:
--> 144         next(self.gen)
    145     except StopIteration:
    146         return False

    [... skipping hidden 3 frame]

File ~/miniforge3/envs/metatopia_new/lib/python3.11/site-packages/jax/_src/core.py:1305, in _why_alive(ignore_ids, x)
   1292 parent = parents(child)[0]  # just pick one parent
   1294 # For namespaces (like modules and class instances) and closures, the
   1295 # references may form a simple chain: e.g. instance refers to its own
   1296 # __dict__ which refers to child, or function refers to its __closure__
   (...)   1302 #  https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599
   1303 # To prevent this collapsing behavior, just comment out this code block.
   1304 if (isinstance(parent, dict) and
-> 1305     getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]):
   1306   parent = parents(parent)[0]
   1307 elif type(parent) is types.CellType:

IndexError: list index out of range

SNMS95 · Apr 13 '25

Unfortunately I cannot reproduce the error with the last example you posted. With that, I only get an AttributeError that is easily fixed.

johannahaffner · Apr 13 '25

I tried the code as well and managed to reproduce the error. I think it might be related to the following issue in JAX: https://github.com/jax-ml/jax/issues/26560

itk22 · Apr 14 '25

Which one did you try, the first or the second example? If we can figure out what the difference between your setups and mine is, that would be a step forward! I used the somewhat reduced example @SNMS95 posted yesterday, not the longer one from three days ago.

jax 0.5.3, equinox 0.12.1, optimistix 0.0.10 (on current main). MacBook with M1 processor.

johannahaffner · Apr 14 '25

I tried it with the following settings: a conda environment with pip-installed versions of python=3.11, jax=0.5.3, equinox=0.12.1, optimistix=0.0.10, on a Mac with an M3 processor.

SNMS95 · Apr 14 '25

I can reproduce this with Python 3.11.9, but not with Python 3.13.1. I could reduce the MWE to this:

import jax
import optimistix as optx


jax.config.update("jax_check_tracer_leaks", True)

def match(y, args):  # Has trivial root at 0.5
    del args
    return y - 0.5

solver = optx.Bisection(rtol=1e-3, atol=1e-3)

y0 = 0.75
optx.root_find(match, solver, y0, options=dict(lower=0, upper=1))

This most probably needs a fix over in optimistix; I can look into it tomorrow. @patrick-kidger, do you have suggestions as to what changed between these Python versions? The root finder can be swapped for another one; enabling the tracer leak checks is required to reproduce it.
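
For example, a minimal variant of the MWE above with a different root finder (a sketch on my side; Newton does not need the lower/upper options, and SNMS95 reported above that the error also shows up with Newton):

newton = optx.Newton(rtol=1e-3, atol=1e-3)
optx.root_find(match, newton, y0)  # reportedly fails the same way with jax_check_tracer_leaks enabled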

johannahaffner · Apr 14 '25

Alright, this is not an optimistix error after all. The IndexError is caused by an interaction between equinox runtime errors and Python 3.11; it works on Python 3.13 (I haven't checked 3.12).

import equinox as eqx
import jax
import jax.numpy as jnp


jax.config.update("jax_check_tracer_leaks", True)

x1 = jnp.ones(3)
x2 = eqx.error_if(x1, jnp.array(True), "This should not be an IndexError")   # pred can be False or True

johannahaffner · Apr 14 '25

Interesting! This Equinox-only MWE is super useful.

As another preliminary result on my side, I can avoid this error by dropping back to JAX 0.4.35, so it seems that this is something that came in with JAX's 'stackless' change. I can also avoid this error by removing our JIT wrapper around branched_error_if_impl. Probably the next step is to further simplify this MWE to avoid Equinox too.
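
If a short-term workaround is needed (assuming a pip-based environment, and that the rest of the stack is compatible with the older release), pinning JAX might look like:

pip install "jax==0.4.35" "jaxlib==0.4.35"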

patrick-kidger · Apr 14 '25

Thanks for the pointers! I could get to a JAX-only reproduction by removing everything but the traceback walk from branched_error_if_impl.

johannahaffner · Apr 15 '25