Issue with BFGS + vmap + diagonal
I ran into a bug when using BFGS together with vmap and functions that take a diagonal. As the minimal reproducible example below shows, another optimizer such as NelderMead works fine. I originally hit this issue when minimizing something like MVN(loc=..., covariance_matrix=...).log_prob(...) from distrax, but managed to reproduce it with the offending function, diagonal. I also ran into the same error when using some bijectors from tensorflow_probability (where the offending function was diag), but couldn't reproduce that outside my codebase.
- jax version: 0.6.1
- optimistix version: 0.0.10
Reproduce
import jax
import jax.numpy as jnp
import jax.random as jr
import optimistix as optx

x = jr.normal(jr.key(0), (5, 10))

def inner_fn(y, x):
    z = jnp.outer(x, y)
    return z.diagonal(axis1=-1, axis2=-2).sum()

# Sanity check: vmapping inner_fn on its own works.
jax.vmap(inner_fn)(jr.normal(jr.key(0), (5, 10)), x)

def outer_fn(x, key, solver):
    res = optx.minimise(
        lambda y, _: inner_fn(y, x),
        solver=solver,
        y0=jr.normal(key, (10,)),
        throw=False,
    )
    return res.value

nm = optx.NelderMead(rtol=1e-3, atol=1e-3)
bfgs = optx.BFGS(rtol=1e-3, atol=1e-3)
# Works:
jax.vmap(outer_fn, in_axes=(0, 0, None))(x, jr.split(jr.key(0), 5), nm)
# Raises the AssertionError below on JAX 0.6.1:
jax.vmap(outer_fn, in_axes=(0, 0, None))(x, jr.split(jr.key(0), 5), bfgs)
Error
Traceback (most recent call last):
  File ".../test.py", line 14, in <module>
    jax.vmap(inner_fn)(jr.normal(jr.key(0), (5, 10)), x)
  File ".../test.py", line 11, in inner_fn
    return z.diagonal(axis1=-1, axis2=-2).sum()
  File ".../.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 1087, in meth
    return getattr(self.aval, name).fun(self, *args, **kwargs)
  File ".../.venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 190, in _diagonal
    return lax_numpy.diagonal(self, offset=offset, axis1=axis1, axis2=axis2)
  File ".../.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 7660, in diagonal
    return lax.platform_dependent(a, default=_default_diag, mosaic=_mosaic_diag)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError: The index of a cond with branches_platforms should be a platform_index and should never be mapped

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File ".../test.py", line 29, in <module>
    print(jax.vmap(outer_fn, in_axes=(0, 0, None))(x, jr.split(jr.key(0), 5), bfgs))
  File ".../test.py", line 17, in outer_fn
    res = optx.minimise(
AssertionError: The index of a cond with branches_platforms should be a platform_index and should never be mapped
Thanks for the report! I can reproduce this with JAX 0.6.1 and it disappears with 0.6.0. It affects solvers that have a cond between an accepted and a rejected branch in their step (BFGS, NonlinearCG, ...).
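As background on the accepted/rejected cond mentioned here: when the predicate of a `lax.cond` is batched, `jax.vmap` does not keep a `cond` in the traced program. It traces both branches and combines their outputs with `select_n`. The sketch below uses a hypothetical `step` function standing in for a solver step; it only illustrates that transformation and does not by itself reproduce the AssertionError.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(pred, y):
    # Hypothetical stand-in for a solver step: one branch for an accepted
    # step, one for a rejected step.
    return lax.cond(pred, lambda v: v * 2.0, lambda v: v - 1.0, y)


preds = jnp.array([True, False, True])  # batched predicate
ys = jnp.array([1.0, 2.0, 3.0])

out = jax.vmap(step)(preds, ys)

# Inspect the traced program: with a batched predicate, both branches are
# evaluated and the results are merged with select_n.
jaxpr_text = str(jax.make_jaxpr(jax.vmap(step))(preds, ys))
print(out)
print("select_n" in jaxpr_text)
```

With an unbatched predicate (e.g. `in_axes=(None, 0)`), `vmap` can keep the `cond` and batch each branch instead.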
This seems to be a (very) new thing. @nstarman, can you comment on what https://github.com/patrick-kidger/quax/pull/64 does? Do we need something similar here
https://github.com/patrick-kidger/optimistix/blob/9927984fb8cbec77f9514fad7af076dce64e3993/optimistix/_misc.py#L247
or is this a bug outside of the Equinox ecosystem that we just happen to run into?
I have a hunch that this is actually a JAX bug: namely, that JAX is spuriously introducing a batch tracer (it does this in a few places where that made the implementation easier), and that here it runs into a case where a batch tracer is disallowed.
I think we'd need to narrow this down to an MWE without Optimistix so that we can report it upstream.
My feeling too, unless we inadvertently make the platform index an array somewhere. I have a long train ride today and can dig into this then :)
This is indeed a JAX-only error and has now been reported upstream. In the meantime, I suggest downgrading to JAX 0.6.0 to avoid it.
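If a codebase cannot be pinned everywhere, a fail-fast check at startup can at least flag the affected release. This is a minimal sketch: the helper names `parse_release` and `check_jax_version` are hypothetical, and the affected-version set contains only 0.6.1 because that is the only release reported in this thread. Pinning itself would be something like `pip install "jax==0.6.0"`.

```python
import warnings

# Assumption: only JAX 0.6.1 is affected, based solely on the report in
# this thread. Extend this set if other releases turn out to fail too.
AFFECTED_JAX_VERSIONS = {(0, 6, 1)}


def parse_release(version: str) -> tuple:
    """Extract the leading numeric release, e.g. "0.6.1" -> (0, 6, 1)."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_jax_version(version: str) -> bool:
    """Warn and return True if `version` is one of the affected releases."""
    affected = parse_release(version) in AFFECTED_JAX_VERSIONS
    if affected:
        warnings.warn(
            f"JAX {version} was reported to raise an AssertionError when "
            "vmapping optimistix solvers such as BFGS over functions that use "
            "jnp.diagonal; consider pinning jax==0.6.0.",
            RuntimeWarning,
        )
    return affected
```

Usage: call `check_jax_version(jax.__version__)` once after `import jax`.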