jaxopt

second-order derivatives with implicit diff

[Open] SNMS95 opened this issue 11 months ago • 2 comments

I was trying to use the custom_root decorator to differentiate through a solver. Taking gradients works well. However, when I use jax.hessian, I get the error "can't apply forward-mode autodiff (jvp) to a custom_vjp function". jax.hessian is implemented as jax.jacfwd(jax.jacrev(...)), so it needs forward mode. According to the JAX documentation, a function with a custom_jvp rule supports both forward- and reverse-mode differentiation, whereas a custom_vjp function supports reverse mode only.
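
The difference can be reproduced without jaxopt. Below is a minimal sketch with a toy function x·sin(x); the names f_vjp and f_jvp are hypothetical. The same derivative rule registered through jax.custom_vjp supports jax.grad but not jax.hessian, while the jax.custom_jvp version supports both:

```python
import jax
import jax.numpy as jnp

# Toy function x * sin(x) with a hand-written reverse-mode (VJP) rule.
@jax.custom_vjp
def f_vjp(x):
    return jnp.sin(x) * x

def f_vjp_fwd(x):
    # Forward pass: return the output and save x as the residual.
    return f_vjp(x), x

def f_vjp_bwd(x, g):
    # d/dx [x sin x] = sin x + x cos x
    return (g * (jnp.sin(x) + x * jnp.cos(x)),)

f_vjp.defvjp(f_vjp_fwd, f_vjp_bwd)

# The same function with a forward-mode (JVP) rule instead.
@jax.custom_jvp
def f_jvp(x):
    return jnp.sin(x) * x

@f_jvp.defjvp
def f_jvp_rule(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return f_jvp(x), (jnp.sin(x) + x * jnp.cos(x)) * x_dot

x0 = 1.5
print(jax.grad(f_vjp)(x0))     # reverse mode on custom_vjp: works
print(jax.hessian(f_jvp)(x0))  # forward-over-reverse on custom_jvp: works

try:
    jax.hessian(f_vjp)(x0)     # jacfwd(jacrev(...)) needs a JVP rule
except TypeError as err:
    print("custom_vjp:", err)
```

With custom_jvp, JAX derives reverse mode by transposing the (linear) JVP rule, which is why both modes are available.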

I saw that internally, custom_root implements only a custom_vjp rule. Is there any way to choose a custom_jvp rule instead?
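
For reference, a JVP rule for a root-finding solver follows from the implicit function theorem. This is a standard derivation, written out here independently of jaxopt's source. If the solution x*(θ) satisfies the optimality condition F(x*(θ), θ) = 0, differentiating along a tangent θ̇ gives a linear system for the JVP ẋ. For the ridge objective in the example below, F(x, θ) = Xᵀ(Xx − y) + θx:

$$
\partial_x F(x^\star,\theta)\,\dot{x} + \partial_\theta F(x^\star,\theta)\,\dot{\theta} = 0
\quad\Longrightarrow\quad
\dot{x} = -\big[\partial_x F(x^\star,\theta)\big]^{-1}\,\partial_\theta F(x^\star,\theta)\,\dot{\theta}
$$

$$
\text{ridge: }\quad \partial_x F = X^\top X + \theta I,\qquad \partial_\theta F = x^\star
\quad\Longrightarrow\quad
\dot{x} = -\left(X^\top X + \theta I\right)^{-1} x^\star\,\dot{\theta}
$$

When this rule is written with differentiable JAX operations, it can itself be differentiated again, which is what forward-over-reverse (jax.hessian) needs.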

A minimal example is as follows:

```python
import jax
import jax.numpy as jnp
import numpy as onp
from jaxopt import implicit_diff

def f(x, theta, X_train, y_train):  # Objective function (ridge regression)
    residual = jnp.dot(X_train, x) - y_train
    return (jnp.sum(residual ** 2) + theta * jnp.sum(x ** 2)) / 2

# Optimality condition: F(x, theta, X_train, y_train) = 0 at the minimizer
F = jax.grad(f, argnums=0)

@implicit_diff.custom_root(F)
def ridge_solver(init_x, theta, X_train, y_train):
    del init_x  # Initialization not used in this solver
    XX = jnp.dot(X_train.T, X_train)
    Xy = jnp.dot(X_train.T, y_train)
    I = jnp.eye(X_train.shape[1])  # Identity matrix
    # Finds the ridge regression solution by solving a linear system
    return jnp.linalg.solve(XX + theta * I, Xy)

init_x = None
# Create some data.
onp.random.seed(0)
X_train = onp.random.randn(100, 10)
y_train = onp.random.randn(100)

# jax.hessian = jacfwd(jacrev(...)) needs forward mode, so this raises
print(jax.hessian(ridge_solver, argnums=1)(init_x, 10.0, X_train, y_train))
```

This fails with the error:

```
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
```
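
Since custom_root only registers a custom_vjp rule, one possible workaround is to write the implicit-differentiation rule by hand with jax.custom_jvp. This is a minimal sketch, assuming the same ridge objective and data as above. The names ridge_closed_form and ridge_solver_jvp are hypothetical, the unused init_x argument is dropped, and ∂ₓF is formed densely with jax.jacfwd, which is only practical for small x:

```python
import jax
import jax.numpy as jnp
import numpy as onp

def f(x, theta, X_train, y_train):  # same ridge objective as in the issue
    residual = jnp.dot(X_train, x) - y_train
    return (jnp.sum(residual ** 2) + theta * jnp.sum(x ** 2)) / 2

F = jax.grad(f, argnums=0)  # optimality condition F(x*, theta, X, y) = 0

def ridge_closed_form(theta, X_train, y_train):
    XX = jnp.dot(X_train.T, X_train)
    Xy = jnp.dot(X_train.T, y_train)
    I = jnp.eye(X_train.shape[1])
    return jnp.linalg.solve(XX + theta * I, Xy)

# Hypothetical wrapper: the same solver, but with a forward-mode (JVP) rule.
ridge_solver_jvp = jax.custom_jvp(ridge_closed_form)

@ridge_solver_jvp.defjvp
def ridge_solver_jvp_rule(primals, tangents):
    x_star = ridge_solver_jvp(*primals)
    # Directional derivative of F w.r.t. (theta, X, y), holding x* fixed
    _, F_dot = jax.jvp(lambda *p: F(x_star, *p), primals, tangents)
    # Jacobian of F w.r.t. x at the solution (here X^T X + theta I)
    A = jax.jacfwd(F, argnums=0)(x_star, *primals)
    # Implicit function theorem: A x_dot + F_dot = 0
    x_dot = jnp.linalg.solve(A, -F_dot)
    return x_star, x_dot

onp.random.seed(0)
X_train = jnp.asarray(onp.random.randn(100, 10))
y_train = jnp.asarray(onp.random.randn(100))

hess = jax.hessian(ridge_solver_jvp, argnums=0)(10.0, X_train, y_train)
print(hess.shape)  # (10,)
```

The rule uses only jax.jvp, jax.jacfwd and jnp.linalg.solve, so JAX can differentiate it again and jax.hessian goes through. For a large problem, the dense jacfwd/solve pair would be replaced by a matrix-free linear solve.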

SNMS95, Jul 13 '23 11:07

@mblondel Is this going to be implemented?

I need this because I have a black-box solver from PETSc that I would like to use to solve the linear system.

Do you know if I can use jax.lax.custom_linear_solve instead?
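
On the jax.lax.custom_linear_solve question, here is a minimal sketch. It assumes the external solver can be called from traced JAX code; jnp.linalg.solve stands in for it, and ridge_solver_cls is a hypothetical name. custom_linear_solve defines derivatives implicitly from matvec rather than by differentiating through solve. With symmetric=True, the same solve is reused for the transposed system, so forward and reverse mode are both available, which should make jax.hessian work here:

```python
import jax
import jax.numpy as jnp
import numpy as onp

def ridge_solver_cls(theta, X_train, y_train):
    XX = jnp.dot(X_train.T, X_train)
    Xy = jnp.dot(X_train.T, y_train)
    n = XX.shape[0]

    def matvec(v):
        # Linear operator A v = (X^T X + theta I) v; theta is closed over,
        # so derivatives w.r.t. theta are taken through this function.
        return jnp.dot(XX, v) + theta * v

    def solve(_matvec, b):
        # Stand-in for an external linear solver returning A^{-1} b.
        # Derivatives are defined implicitly via matvec, not through this.
        return jnp.linalg.solve(XX + theta * jnp.eye(n), b)

    # A is symmetric, so `solve` is also used for the transposed system.
    return jax.lax.custom_linear_solve(matvec, Xy, solve, symmetric=True)

onp.random.seed(0)
X_train = jnp.asarray(onp.random.randn(100, 10))
y_train = jnp.asarray(onp.random.randn(100))
print(jax.hessian(ridge_solver_cls, argnums=0)(10.0, X_train, y_train))
```

Whether a PETSc solve can be plugged in as `solve` depends on how it is exposed to JAX, which this sketch does not cover.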

SNMS95, Jul 18 '23 19:07

Is this planned? It would make it easy to nest optimizers.

benjaminvatterj, Feb 03 '24 17:02