
Issue with BFGS + vmap + diagonal

Status: open · gorold opened this issue 7 months ago · 4 comments

I ran into a bug when combining BFGS, vmap, and functions that call diagonal. As the minimal reproducible example shows, another optimizer such as NelderMead works fine. I originally hit this when minimizing something like MVN(loc=..., covariance_matrix=...).log_prob(...) from distrax, but managed to reproduce it with the offending function, diagonal, on its own. I also hit the same error when using some bijectors from tensorflow_probability (there the offending function was diag), but couldn't reproduce that outside my codebase.

  • jax version: 0.6.1
  • optimistix version: 0.0.10

Reproduce

```python
import jax
import jax.numpy as jnp
import jax.random as jr
import optimistix as optx


x = jr.normal(jr.key(0), (5, 10))

def inner_fn(y, x):
    z = jnp.outer(x, y)
    return z.diagonal(axis1=-1, axis2=-2).sum()


# vmapping inner_fn on its own does not raise.
jax.vmap(inner_fn)(jr.normal(jr.key(0), (5, 10)), x)

def outer_fn(x, key, solver):
    # Minimise inner_fn over y for a fixed x; the objective's second
    # argument (args) is unused.
    res = optx.minimise(
        lambda y, _: inner_fn(y, x),
        solver=solver,
        y0=jr.normal(key, (10,)),
        throw=False,
    )
    return res.value

nm = optx.NelderMead(rtol=1e-3, atol=1e-3)
bfgs = optx.BFGS(rtol=1e-3, atol=1e-3)

# Works.
jax.vmap(outer_fn, in_axes=(0, 0, None))(x, jr.split(jr.key(0), 5), nm)
# Raises the AssertionError below on JAX 0.6.1.
jax.vmap(outer_fn, in_axes=(0, 0, None))(x, jr.split(jr.key(0), 5), bfgs)
```

Error

```
Traceback (most recent call last):
  File ".../test.py", line 14, in <module>
    jax.vmap(inner_fn)(jr.normal(jr.key(0), (5, 10)), x)
  File ".../test.py", line 11, in inner_fn
    return z.diagonal(axis1=-1, axis2=-2).sum()
  File ".../.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 1087, in meth
    return getattr(self.aval, name).fun(self, *args, **kwargs)
  File ".../.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 190, in _diagonal
    return lax_numpy.diagonal(self, offset=offset, axis1=axis1, axis2=axis2)
  File ".../.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 7660, in diagonal
    return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError: The index of a cond with branches_platforms should be a platform_index and should never be mapped

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../test.py", line 29, in <module>
    print(jax.vmap(outer_fn, in_axes=(0, 0, None))(x, jr.split(jr.key(0), 5), bfgs))
  File ".../test.py", line 17, in outer_fn
    res = optx.minimise(
AssertionError: The index of a cond with branches_platforms should be a platform_index and should never be mapped
```
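The traceback shows `jnp.diagonal` in this JAX version dispatching through `lax.platform_dependent`, and the assertion concerns a cond over platforms. A possible stopgap, assuming that dispatch is what trips up the vmapped solver (not verified against 0.6.1 in this thread), is to take the diagonal with plain integer indexing so `jnp.diagonal` is never called. `diagonal_by_indexing` is a hypothetical helper, not a JAX or Optimistix API:

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def diagonal_by_indexing(z):
    """Hypothetical helper: main diagonal of the last two axes of a square `z`.

    Avoids jnp.diagonal, which the traceback shows going through
    lax.platform_dependent; this uses an ordinary gather instead.
    """
    idx = jnp.arange(z.shape[-1])
    # Advanced indexing on the last two axes selects z[..., i, i] for each i.
    return z[..., idx, idx]


def inner_fn(y, x):
    z = jnp.outer(x, y)
    return diagonal_by_indexing(z).sum()


x = jr.normal(jr.key(0), (5, 10))
y = jr.normal(jr.key(1), (5, 10))
out = jax.vmap(inner_fn)(y, x)
```

This `inner_fn` could then be dropped into `outer_fn` unchanged; whether the BFGS call then succeeds on 0.6.1 is untested here.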

gorold, Jun 08 '25 08:06

Thanks for the report! I can reproduce this with JAX 0.6.1 and it disappears with 0.6.0. It affects solvers that have a cond between an accepted and a rejected branch in their step (BFGS, NonlinearCG, ...).

This seems to be a (very) new thing. @nstarman, can you comment on what https://github.com/patrick-kidger/quax/pull/64 does? Do we need something similar here:

https://github.com/patrick-kidger/optimistix/blob/9927984fb8cbec77f9514fad7af076dce64e3993/optimistix/_misc.py#L247

or is this a bug outside of the Equinox ecosystem that we just happen to run into?

johannahaffner, Jun 08 '25 10:06

I have a hunch that this is actually a JAX bug: namely, that it is spuriously introducing a batch tracer (they do this in a few places where it made the implementation easier), but that here they are running into a case where a batch tracer is disallowed.
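The batch-tracer hunch can be illustrated in miniature. As far as we understand JAX's batching rule for `lax.cond`, when the predicate itself is batched (as it is when each vmapped solve makes its own accept/reject decision), the `cond` is replaced by evaluating both branches on batched operands and combining the results with `select_n`; with an unbatched predicate the `cond` survives. The `step` function is hypothetical, not Optimistix code, and this sketch does not reproduce the bug; it only shows where the extra batching comes from:

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(accept, y, y_new):
    # Hypothetical accept/reject step: take y_new if accepted, else keep y.
    return lax.cond(accept, lambda a, b: b, lambda a, b: a, y, y_new)


ys = jnp.zeros((4, 3))
ys_new = jnp.ones((4, 3))
accept = jnp.array([True, False, True, False])

# Batched predicate: one accept/reject decision per batch element.
batched_jaxpr = str(jax.make_jaxpr(jax.vmap(step))(accept, ys, ys_new))

# Unbatched predicate: one decision shared by the whole batch.
shared_jaxpr = str(
    jax.make_jaxpr(jax.vmap(step, in_axes=(None, 0, 0)))(jnp.array(True), ys, ys_new)
)

print(batched_jaxpr)
print(shared_jaxpr)
```

Comparing the two printed jaxprs shows the difference: only the shared-predicate version keeps a `cond` primitive.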

I think we'd need to bisect this down to an MWE without Optimistix, so that we can report it upstream.

patrick-kidger, Jun 08 '25 11:06

My feeling too, unless we inadvertently make the platform index an array somewhere. I have a long train ride today and can dig into this then :)

johannahaffner, Jun 08 '25 12:06

This is indeed a JAX-only error and has now been reported upstream. In the meantime, I suggest downgrading to JAX 0.6.0 to avoid it.
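A project that cannot pin JAX right away might prefer to detect the affected release at import time. A minimal sketch, assuming (per this thread) that only JAX 0.6.1 is known to be affected; whether later releases are affected is not established here, and the helper names are hypothetical:

```python
import re
import warnings

import jax

# Releases reported in this thread as affected; only 0.6.1 was confirmed.
AFFECTED_JAX_VERSIONS = {(0, 6, 1)}


def parse_version(version: str) -> tuple[int, ...]:
    """Leading major.minor.patch components, e.g. "0.6.1.dev0" -> (0, 6, 1)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def warn_if_affected(version: str = jax.__version__) -> bool:
    """Emit a RuntimeWarning and return True if `version` is a known-affected release."""
    affected = parse_version(version) in AFFECTED_JAX_VERSIONS
    if affected:
        warnings.warn(
            f"JAX {version} may raise 'The index of a cond with branches_platforms "
            "should be a platform_index' under vmap; consider pinning jax==0.6.0.",
            RuntimeWarning,
            stacklevel=2,
        )
    return affected
```

Calling `warn_if_affected()` once at import keeps the check in one place and is easy to delete once a fixed JAX release is available.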

johannahaffner, Jun 09 '25 11:06