
Speeding up Newton Update

Open · BabaYara opened this issue 4 months ago · 2 comments

Hello Patrick,

Thanks a lot for all the work you are putting in here. I have been working on a problem that I can solve with a small-scale neural network, but I need to use stochastic Newton updates. So the fundamental problem I am facing now is trying to speed up the Newton part of the code as much as possible.

The current Newton update does something like this:

```python
@eqx.filter_jit
def step_h(mlp, xs, ys):
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    a, s = eqx.partition(mlp, eqx.is_inexact_array)
    flat_a, unflat_a = flatten_util.ravel_pytree(a)
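    # Dense Hessian of the loss w.r.t. the flattened parameter vector (full n x n matrix)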
    h = jax.hessian(loss_h)(flat_a, s, xs, ys, unflat_a)
    g_flat, unflat = flatten_util.ravel_pytree(grads)
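    # Newton direction via an explicit pseudo-inverse: -H^+ g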
    updates = unflat(-1 * jnp.linalg.pinv(h) @ g_flat)
    mlp = eqx.apply_updates(mlp, updates)
    return mlp, vals
```
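
For context, the update being applied is just the plain (undamped) Newton step, with `pinv` standing in for the inverse:

$$\theta \leftarrow \theta - H^{-1}\,\nabla_\theta L(\theta), \qquad H = \nabla_\theta^2 L(\theta)$$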

I know this is highly inefficient and was hoping to get some advice on optimizing the code. Here is the working example this is based on.

```python
import jax
import equinox as eqx
from jax import numpy as jnp
import matplotlib.pyplot as plt
from jax import flatten_util

key = jax.random.PRNGKey(42)
key, subkey1, subkey2 = jax.random.split(key, 3)
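# Two Gaussian clusters: class 0 centred at (-1, -1), class 1 at (+1, +1)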
data = jnp.concatenate((jax.random.normal(subkey1, shape=(100, 2)) * 0.1 - 1, jax.random.normal(subkey2, shape=(100, 2)) * 0.1 + 1), axis=0)
labels = jnp.array([0] * 100 + [1] * 100)
plt.scatter(data[:,0], data[:,1])
plt.show()

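# Small MLP: two ReLU hidden layers, log_softmax output (so the loss reads log-probabilities)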
class MLP(eqx.Module):
    layers: list

    def __init__(self, key):
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(2, 10, key=key1),
            jax.nn.relu,
            eqx.nn.Linear(10, 12, key=key2),
            jax.nn.relu,
            eqx.nn.Linear(12, 2, key=key3),
            jax.nn.log_softmax
        ]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

def loss_fn(model, ins, ytrue):
    pred_y = eqx.filter_vmap(model)(ins)
    return cross_entropy(ytrue, pred_y)

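# pred_y holds log-probabilities (the model ends in log_softmax), so this is the mean negative log-likelihood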
def cross_entropy(y, pred_y):
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)

@eqx.filter_jit
def compute_accuracy(m, x, y):
    pred_y = eqx.filter_vmap(m)(x)
    pred_y = jnp.argmax(pred_y, axis=1)
    return jnp.mean(y == pred_y)

key = jax.random.PRNGKey(42)

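# Baseline: plain gradient descent with a fixed learning rate of 0.1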
@eqx.filter_jit
def step(mlp, xs, ys):
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    updates = jax.tree_util.tree_map(lambda g: -0.1 * g, grads)
    mlp = eqx.apply_updates(mlp, updates)
    return mlp, vals

epochs = 50
batch_size = 100
grad_loss = []
grad_acc = []

key, subkey = jax.random.split(key, 2)
model = MLP(subkey)

for e in range(epochs):
    if e % 20 == 0:
        print(e, "/", epochs)
    
    key, subkey = jax.random.split(key, 2)
    inds = jax.random.randint(subkey, minval=0, maxval=len(data), shape=(batch_size,))
    inputs = data[inds]
    ls = labels[inds]
    
    model, loss = step(model, inputs, ls)
    grad_loss.append(loss)
    grad_acc.append(compute_accuracy(model, data, labels))

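# Same loss, but written as a function of the flattened parameter vector so jax.hessian
# can be taken with respect to a single 1-D array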
@eqx.filter_jit
def loss_h(arrs, static, ins, ytrue, uf):
    arrs = uf(arrs)
    model = eqx.combine(arrs, static)
    pred_y = eqx.filter_vmap(model)(ins)
    return cross_entropy(ytrue, pred_y)

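# Newton step: identical to the snippet quoted above (dense Hessian + explicit pseudo-inverse)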
@eqx.filter_jit
def step_h(mlp, xs, ys):
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    a, s = eqx.partition(mlp, eqx.is_inexact_array)
    flat_a, unflat_a = flatten_util.ravel_pytree(a)
    h = jax.hessian(loss_h)(flat_a, s, xs, ys, unflat_a)
    g_flat, unflat = flatten_util.ravel_pytree(grads)
    updates = unflat(-1 * jnp.linalg.pinv(h) @ g_flat)
    mlp = eqx.apply_updates(mlp, updates)
    return mlp, vals

key = jax.random.PRNGKey(1)
key, subkey = jax.random.split(key, 2)
model_h = MLP(subkey)

epochs = 50
batch_size = 100
h_loss = []
h_acc = []

for e in range(epochs):
    if e % 20 == 0:
        print(e, "/", epochs)
    
    key, subkey = jax.random.split(key, 2)
    inds = jax.random.randint(subkey, minval=0, maxval=len(data), shape=(batch_size,))
    inputs = data[inds]
    ls = labels[inds]
    
    model_h, loss = step_h(model_h, inputs, ls)
    h_loss.append(loss_fn(model_h, data, labels))
    h_acc.append(compute_accuracy(model_h, data, labels))

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

ax1.plot(h_loss, label="Newton")
ax2.plot(h_acc, label="Newton")

ax1.plot(grad_loss, label="Grad")
ax2.plot(grad_acc, label="Grad")

plt.legend()
plt.show()
```
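
In case it makes the question more concrete: the one small fix I can see myself is replacing `jnp.linalg.pinv(h) @ g_flat` with `jnp.linalg.solve(h, g_flat)` (assuming the Hessian is non-singular), but I suspect the real cost is materialising the dense Hessian in the first place. One direction I have been wondering about, purely as an untested sketch reusing `loss_h` from the script above, is a matrix-free step that solves `H d = g` with conjugate gradients built on Hessian-vector products:

```python
from jax.scipy.sparse.linalg import cg

# Rough, untested sketch: matrix-free Newton/CG step.
# Each Hessian-vector product comes from forward-over-reverse autodiff,
# so the full n x n Hessian is never materialised.
@eqx.filter_jit
def step_cg(mlp, xs, ys):
    vals, grads = eqx.filter_value_and_grad(loss_fn)(mlp, xs, ys)
    a, s = eqx.partition(mlp, eqx.is_inexact_array)
    flat_a, unflat_a = flatten_util.ravel_pytree(a)
    g_flat, unflat = flatten_util.ravel_pytree(grads)

    def hvp(v):
        # H @ v, differentiating the gradient of loss_h along direction v
        grad_fn = lambda w: jax.grad(loss_h)(w, s, xs, ys, unflat_a)
        return jax.jvp(grad_fn, (flat_a,), (v,))[1]

    # Approximately solve H d = g with a fixed CG iteration budget.
    # (CG assumes a positive-definite Hessian; in general some damping,
    # e.g. solving (H + lam * I) d = g, would presumably be needed.)
    d, _ = cg(hvp, g_flat, maxiter=20)

    mlp = eqx.apply_updates(mlp, unflat(-d))
    return mlp, vals
```

I am not sure this is sound as written (see the positive-definiteness caveat above), so any pointers on the right way to structure this would be much appreciated.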

BabaYara · Feb 23 '24 21:02