equinox icon indicating copy to clipboard operation
equinox copied to clipboard

Speeding up Newton Update

Open BabaYara opened this issue 4 months ago • 2 comments

Hello Patrick,

Thanks a lot for all the work you are putting in here. I have been working on a problem that I can solve with a small-scale neural network, but it requires stochastic Newton updates. The fundamental problem I am facing now is how to speed up the Newton part of the code as much as possible.

The current newton update does something like this:


def step_h(mlp, xs, ys):
    """Apply one Newton-style update to ``mlp`` and return it with the loss.

    The loss value and gradient come from ``loss_fn`` on the batch
    ``(xs, ys)``; the Hessian comes from ``loss_h`` evaluated on the
    flattened inexact-array leaves of ``mlp``. The update is minus the
    pseudo-inverse of the Hessian applied to the flattened gradient.

    Returns:
        A ``(updated_mlp, loss_before_update)`` pair.
    """
    loss, grad_tree = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)

    # Separate the differentiable arrays from everything else so the
    # parameters can be handled as one flat vector.
    params, static = eqx.partition(mlp, eqx.is_inexact_array)
    flat_params, unravel_params = flatten_util.ravel_pytree(params)
    hessian = jax.hessian(loss_h)(flat_params, static, xs, ys, unravel_params)

    flat_grad, unravel_grad = flatten_util.ravel_pytree(grad_tree)
    neg_inverse = -1 * jnp.linalg.pinv(hessian)
    updates = unravel_grad(neg_inverse @ flat_grad)

    return eqx.apply_updates(mlp, updates), loss


I know this is highly inefficient and was hoping to get some advice on optimizing the code. Here is the working example this is based on.


import jax
import equinox as eqx
from jax import numpy as jnp
import matplotlib.pyplot as plt
from jax import flatten_util

# Synthetic two-class data: 100 points scattered around (-1, -1) labelled 0
# and 100 points scattered around (+1, +1) labelled 1.
key = jax.random.PRNGKey(42)
key, subkey1, subkey2 = jax.random.split(key, 3)
blob_low = jax.random.normal(subkey1, shape=(100, 2)) * 0.1 - 1
blob_high = jax.random.normal(subkey2, shape=(100, 2)) * 0.1 + 1
data = jnp.concatenate((blob_low, blob_high), axis=0)
labels = jnp.array([0] * 100 + [1] * 100)
plt.scatter(data[:, 0], data[:, 1])

class MLP(eqx.Module):
    """Stack of three linear layers mapping 2 inputs to 2 outputs.

    The layers are applied one after another with no activation function
    in between, and the output of the last layer is returned unchanged.
    """

    # Layers applied in order by ``__call__``.
    layers: list

    def __init__(self, key):
        """Create the layers, each initialised from its own subkey of ``key``."""
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(2, 10, key=key1),
            eqx.nn.Linear(10, 12, key=key2),
            eqx.nn.Linear(12, 2, key=key3),
        ]  # closing bracket was missing, which made the file unparseable

    def __call__(self, x):
        """Feed ``x`` through every layer in order and return the result.

        Callers in this file wrap the model in ``eqx.filter_vmap``, so ``x``
        is presumably a single sample rather than a batch.
        """
        for layer in self.layers:
            x = layer(x)
        return x

def loss_fn(model, ins, ytrue):
    """Return the ``cross_entropy`` loss of ``model`` on a batch.

    ``model`` is mapped over the leading axis of ``ins`` and the resulting
    predictions are scored against the labels ``ytrue``.
    """
    batched_model = eqx.filter_vmap(model)
    return cross_entropy(ytrue, batched_model(ins))

def cross_entropy(y, pred_y):
    """Return the mean negative log-likelihood of the labels ``y``.

    Args:
        y: Integer class indices, one per row of ``pred_y``.
        pred_y: Per-class scores with one row per sample and one column per
            class. They may be raw logits or log-probabilities.

    Returns:
        Scalar mean of ``-log p(y_i | sample_i)`` over the rows.
    """
    # Normalise each row to log-probabilities first. Without this step the
    # value is just minus the mean selected score, which is unbounded below
    # for raw logits. log_softmax leaves already-normalised log-probabilities
    # unchanged, so callers that pass those get the same result as before.
    log_probs = jax.nn.log_softmax(pred_y, axis=1)
    picked = jnp.take_along_axis(log_probs, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(picked)

def compute_accuracy(m, x, y):
    """Return the fraction of rows of ``x`` whose arg-max prediction equals ``y``.

    ``m`` is mapped over the leading axis of ``x``; the predicted class of
    each row is the index of its largest output.
    """
    scores = eqx.filter_vmap(m)(x)
    predicted = jnp.argmax(scores, axis=1)
    hits = y == predicted
    return jnp.mean(hits)

# Reset the PRNG key before training; this is the same seed (42) that was
# used to generate the data at the top of the script.
key = jax.random.PRNGKey(42)

def step(mlp, xs, ys):
    """Take one gradient-descent step with a fixed step size of 0.1.

    Args:
        mlp: Model to update.
        xs: Batch of inputs passed to ``loss_fn``.
        ys: Batch of labels passed to ``loss_fn``.

    Returns:
        A ``(updated_mlp, loss_before_update)`` pair.
    """
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    # Use the tree_util spelling: the top-level ``jax.tree_map`` alias is
    # deprecated and no longer exists in newer JAX releases.
    updates = jax.tree_util.tree_map(lambda g: -0.1 * g, grads)
    mlp = eqx.apply_updates(mlp, updates)
    return mlp, vals

# Gradient-descent baseline: train a fresh MLP on random mini-batches and
# record full-dataset loss and accuracy after every step.
epochs = 50
batch_size = 100
grad_loss = []
grad_acc = []

key, subkey = jax.random.split(key, 2)
model = MLP(subkey)

for e in range(epochs):
    if e % 20 == 0:
        print(e, "/", epochs)
    key, subkey = jax.random.split(key, 2)
    # Sample a mini-batch of indices (with replacement).
    inds = jax.random.randint(subkey, minval=0, maxval=len(data), shape=(batch_size,))
    inputs = data[inds]
    ls = labels[inds]
    model, loss = step(model, inputs, ls)
    # Record the loss as well: ``grad_loss`` is plotted later but was never
    # filled, so the "Grad" loss curve came out empty.
    grad_loss.append(loss_fn(model, data, labels))
    grad_acc.append(compute_accuracy(model, data, labels))

def loss_h(arrs, static, ins, ytrue, uf):
    """Evaluate the loss as a function of a flat parameter vector.

    Args:
        arrs: Flat parameter vector; ``uf`` turns it back into a pytree.
        static: Non-differentiated part of the model, recombined with the
            unflattened parameters via ``eqx.combine``.
        ins: Batch of inputs; the rebuilt model is mapped over its leading axis.
        ytrue: Labels scored against the predictions by ``cross_entropy``.
        uf: Callable that unflattens ``arrs``.

    Returns:
        The ``cross_entropy`` value for the rebuilt model on the batch.
    """
    params = uf(arrs)
    rebuilt = eqx.combine(params, static)
    predictions = eqx.filter_vmap(rebuilt)(ins)
    return cross_entropy(ytrue, predictions)

def step_h(mlp, xs, ys):
    """Apply one Newton-style update to ``mlp`` and return it with the loss.

    The loss value and gradient come from ``loss_fn`` on the batch
    ``(xs, ys)``; the Hessian comes from ``loss_h`` evaluated on the
    flattened inexact-array leaves of ``mlp``. The update is minus the
    pseudo-inverse of the Hessian applied to the flattened gradient.

    Returns:
        A ``(updated_mlp, loss_before_update)`` pair.
    """
    loss, grad_tree = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)

    # Separate the differentiable arrays from everything else so the
    # parameters can be handled as one flat vector.
    params, static = eqx.partition(mlp, eqx.is_inexact_array)
    flat_params, unravel_params = flatten_util.ravel_pytree(params)
    hessian = jax.hessian(loss_h)(flat_params, static, xs, ys, unravel_params)

    flat_grad, unravel_grad = flatten_util.ravel_pytree(grad_tree)
    neg_inverse = -1 * jnp.linalg.pinv(hessian)
    updates = unravel_grad(neg_inverse @ flat_grad)

    return eqx.apply_updates(mlp, updates), loss

# Newton run: train a second, freshly seeded MLP with ``step_h`` and record
# full-dataset loss and accuracy after every update.
key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key, 2)
model_h = MLP(subkey)

epochs = 50
batch_size = 100
h_loss, h_acc = [], []

for e in range(epochs):
    if e % 20 == 0:
        print(e, "/", epochs)

    # Draw a mini-batch of indices (with replacement) from the full dataset.
    key, subkey = jax.random.split(key, 2)
    inds = jax.random.randint(
        subkey, minval=0, maxval=len(data), shape=(batch_size,)
    )
    inputs, ls = data[inds], labels[inds]

    model_h, loss = step_h(model_h, inputs, ls)

    h_loss.append(loss_fn(model_h, data, labels))
    h_acc.append(compute_accuracy(model_h, data, labels))

# Compare the two optimisers side by side: loss on the left, accuracy on the
# right, one curve per optimiser.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.plot(h_loss, label="Newton")
ax2.plot(h_acc, label="Newton")

ax1.plot(grad_loss, label="Grad")
ax2.plot(grad_acc, label="Grad")

ax1.set_title("Loss")
ax2.set_title("Accuracy")

# The ``label=`` values above are only displayed once a legend is requested;
# without these calls the two curves cannot be told apart.
ax1.legend()
ax2.legend()



BabaYara avatar Feb 23 '24 21:02 BabaYara