
Not working: Minimizing a loss function using the Levenberg–Marquardt algorithm

Open aitzaz-web opened this issue 1 year ago • 3 comments

Hi,

I am attempting to train a neural network to approximate a 3D function, minimizing the loss with the Levenberg–Marquardt (LM) optimizer in Optimistix. However, the parameters do not update beyond the first optimization step: the loss remains unchanged after the first iteration, which suggests the optimizer is not properly updating the parameters. My code is below; I am not sure where the error lies.

```python
import jax.numpy as jnp
import jax.random as random
import optimistix
import numpy as np
import matplotlib.pyplot as plt
from flax import linen as nn
from sklearn.preprocessing import MinMaxScaler
```

Generate Synthetic 3D Data

```python
key = random.PRNGKey(42)
X_data = jnp.linspace(-5, 5, 20)
Y_data = jnp.linspace(-5, 5, 20)
Z_data = jnp.linspace(-5, 5, 20)

X_mesh, Y_mesh, Z_mesh = jnp.meshgrid(X_data, Y_data, Z_data, indexing='ij')
X_flat, Y_flat, Z_flat = X_mesh.flatten(), Y_mesh.flatten(), Z_mesh.flatten()

def true_function(X, Y, Z):
    return 2.0 * jnp.sin(1.5 * X) + 1.0 * jnp.cos(2.0 * Y) + 0.5 * Z

F_values = true_function(X_flat, Y_flat, Z_flat) + 0.05 * random.normal(key, (len(X_flat),))
X_train = jnp.stack([X_flat, Y_flat, Z_flat], axis=1)
F_train = jnp.array(F_values)

scaler = MinMaxScaler()
X_train = jnp.array(scaler.fit_transform(np.array(X_train)), dtype=jnp.float32)
```

Define Neural Network Model

```python
class TanhMLP(nn.Module):
    layers: list

    @nn.compact
    def __call__(self, x):
        for units in self.layers[:-1]:
            x = nn.Dense(units)(x)
            x = jnp.tanh(x)
        return nn.Dense(self.layers[-1])(x)

key = random.PRNGKey(42)
model = TanhMLP(layers=[3, 128, 128, 64, 1])
params = model.init(key, jnp.ones((1, 3), dtype=jnp.float32))
```

Define Residual Function

```python
def residuals(params, args):
    X, true_F = args
    predicted_F = model.apply(params, X).flatten()
    return [predicted_F - true_F]
```

Optimize Using Levenberg-Marquardt

```python
solver = optimistix.LevenbergMarquardt(
    rtol=1e-4, atol=1e-4, verbose=frozenset(["loss", "step"])
)
solution = optimistix.least_squares(
    residuals, solver, params, args=(X_train, F_train), max_steps=2000
)
optimized_params = solution.value
```

It would be great if I could get some help here!

aitzaz-web avatar Mar 17 '25 03:03 aitzaz-web

> The loss remains unchanged after the first iteration, indicating that the optimizer may not be properly updating the parameters. You can find my code below. Not sure where the error lies.

This is based on the verbose printout?

johannahaffner avatar Mar 17 '25 09:03 johannahaffner

My hypothesis is that the first step, which is accepted by default, is too large. Because you're using a tanh, it takes you into a region of zero gradients. Since you're not getting a reduction in your objective function, the algorithm does not terminate.
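As a quick illustration of the saturation effect described above (not part of the original code): tanh's derivative is 1 - tanh(x)², which equals 1 at the origin but is essentially zero only a few units away, so a large accepted first step can park the pre-activations where gradients vanish.

```python
import jax
import jax.numpy as jnp

# d/dx tanh(x) = 1 - tanh(x)**2: exactly 1 at the origin, but
# essentially zero once |x| exceeds roughly 3.
grad_tanh = jax.grad(jnp.tanh)
g0 = grad_tanh(0.0)  # gradient in the linear region
g5 = grad_tanh(5.0)  # gradient deep in the saturated region
```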

Unfortunately I have no time to dig into this today! Thank you for the issue though, this is really valuable feedback. Some simple heuristics you could try in the meantime:

  • initialise with different random seeds
  • scale down your initial weights and biases, to stay further away from the saturation regions of tanh
  • use GradientDescent with a conservative learning rate for the first few steps to get to a better region (you can do that through optx.least_squares; it will automatically minimise the scalar loss)
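For the second suggestion, one simple way to shrink the initial weights (a hypothetical sketch, not the thread's code) is to rescale every leaf of the params PyTree after `model.init`; the plain dict below merely stands in for Flax's output, with made-up shapes.

```python
import jax
import jax.numpy as jnp

# Stand-in for the PyTree returned by model.init(...) (hypothetical shapes).
params = {
    "Dense_0": {"kernel": jnp.ones((3, 128)), "bias": jnp.zeros((128,))},
    "Dense_1": {"kernel": jnp.ones((128, 1)), "bias": jnp.zeros((1,))},
}

# Multiply every weight and bias by a small factor so the tanh
# pre-activations start near the linear region around zero.
scale = 0.1
params_small = jax.tree_util.tree_map(lambda p: scale * p, params)
```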

johannahaffner avatar Mar 17 '25 09:03 johannahaffner

Thanks for your feedback! Scaling down the layers and weights seemed to get it working.

aitzaz-web avatar Mar 18 '25 19:03 aitzaz-web