
Using `optimistix` with an `equinox` model

Open · frostedoyster opened this issue 9 months ago · 2 comments

Hi everyone, thanks for the great library and apologies in advance for this basic question. I'm trying to find the true minimum of a small neural network, and I thought of using a solver from optimistix together with an equinox model. However, I haven't been able to make the two work together.

Here is a minimal snippet which fails:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx

jax.config.update("jax_enable_x64", True)


X = jax.random.normal(jax.random.PRNGKey(0), (2000, 8))

@jax.vmap
def function(x):
    return x[0] + x[1]**2 + jnp.cos(x[2]) + jnp.sin(x[3]) + x[4]*x[5] + (x[6]*x[7])**3

y = function(X).reshape(-1, 1)

model = eqx.nn.MLP(in_size=8, out_size=1, width_size=4, depth=2, activation=jax.nn.silu, key=jax.random.PRNGKey(0))

static, params = eqx.partition(model, eqx.is_inexact_array)

def loss_fn(params, static, X, y):
    model = eqx.combine(params, static)
    return jnp.sum((jax.vmap(model)(X) - y)**2)

solver = optx.Newton(rtol=1e-5, atol=1e-5)
sol = optx.minimise(loss_fn, solver, params)
```

I'm getting `TypeError: Cannot determine dtype of <PjitFunction of <function silu at 0x742fde959300>>`.

What am I doing wrong? Thank you in advance.

frostedoyster, May 07 '24 17:05

You have static and params the wrong way around.

(Caveat: I've not tried running the code, this was just what jumped out at me.)

patrick-kidger, May 07 '24 17:05

Indeed, that was the issue! Thanks a lot for the reply and for the library!

frostedoyster, May 08 '24 12:05