Using `optimistix` with an `equinox` model
Hi everyone, thanks for the great library and apologies in advance for this basic question.
I'm trying to find the true minimum of a small neural network, and I thought of using a solver from `optimistix` together with an `equinox` model. However, I haven't been able to make the two work together.
Here is a minimal snippet which fails:
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)

X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3

y = function(X).reshape(-1, 1)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))
static, params = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)
I'm getting `TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>`.
What am I doing wrong? Thank you in advance.
You have `static` and `params` the wrong way around.
(Caveat: I've not tried running the code, this was just what jumped out at me.)
Indeed, that was the issue! Thanks a lot for the reply and for the library!