jaxopt
second-order derivatives with implicit diff
I was trying to use the `custom_root` decorator to differentiate through a solver. Taking gradients works well. However, if I try to use `jax.hessian`, I get the error "can't apply forward-mode autodiff (jvp) to a custom_vjp function". According to the JAX documentation, a function defined with `custom_jvp` supports both forward- and reverse-mode differentiation, whereas a function defined with `custom_vjp` supports only reverse mode.
I saw that, internally, `custom_root` implements only a `custom_vjp` rule. Is there any way to choose a `custom_jvp` rule instead?
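The documented difference can be illustrated with two toy functions. `square_vjp` and `square_jvp` are hypothetical names, unrelated to jaxopt. Both versions support reverse mode, but only the `custom_jvp` version can go through `jax.jvp` (and therefore `jax.hessian`). Pure forward mode on the `custom_vjp` version raises a `TypeError`.

```python
import jax
import jax.numpy as jnp


# Toy function with a hand-written reverse-mode rule.
@jax.custom_vjp
def square_vjp(x):
    return x ** 2

def square_fwd(x):
    return x ** 2, x  # save x as the residual for the backward pass

def square_bwd(x, g):
    return (2.0 * x * g,)

square_vjp.defvjp(square_fwd, square_bwd)


# Same function with a hand-written forward-mode rule.
@jax.custom_jvp
def square_jvp(x):
    return x ** 2

@square_jvp.defjvp
def square_jvp_rule(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return x ** 2, 2.0 * x * x_dot


# Reverse mode works for both versions.
grad_vjp = jax.grad(square_vjp)(3.0)  # 6.0
grad_jvp = jax.grad(square_jvp)(3.0)  # 6.0

# The custom_jvp version also supports forward mode, hence jax.hessian.
_, tangent_jvp = jax.jvp(square_jvp, (3.0,), (1.0,))  # 6.0
hess_jvp = jax.hessian(square_jvp)(3.0)               # 2.0

# Pure forward mode on the custom_vjp version raises a TypeError.
try:
    jax.jvp(square_vjp, (3.0,), (1.0,))
    forward_error = None
except TypeError as err:
    forward_error = str(err)
print(forward_error)
```

JAX can derive reverse mode from a `custom_jvp` rule by linearizing and transposing it, but it cannot recover a forward-mode rule from a `custom_vjp` rule.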
A minimal example is as follows:
```python
import jax
import jax.numpy as jnp
import numpy as onp
from jaxopt import implicit_diff

def f(x, theta, X_train, y_train):  # Objective function
    residual = jnp.dot(X_train, x) - y_train
    return (jnp.sum(residual ** 2) + theta * jnp.sum(x ** 2)) / 2

F = jax.grad(f, argnums=0)  # Optimality condition: F(x, theta, ...) = 0

@implicit_diff.custom_root(F)
def ridge_solver(init_x, theta, X_train, y_train):
    del init_x  # Initialization not used in this solver
    XX = jnp.dot(X_train.T, X_train)
    Xy = jnp.dot(X_train.T, y_train)
    I = jnp.eye(X_train.shape[1])  # Identity matrix
    # Finds the ridge regression solution by solving a linear system
    return jnp.linalg.solve(XX + theta * I, Xy)

init_x = None
# Create some data.
X_train = onp.random.randn(100, 10)
y_train = onp.random.randn(100)
print(jax.hessian(ridge_solver, argnums=1)(init_x, 10.0, X_train, y_train))
```
This fails with the error:
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
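One possible workaround is to write the implicit-differentiation rule by hand with `jax.custom_jvp`. This is a sketch under stated assumptions, not jaxopt's API. If the solution x*(theta) satisfies F(x*, theta) = 0, the implicit function theorem gives (dF/dx) x_dot = -(dF/dtheta) theta_dot. Because the rule below is itself ordinary differentiable JAX code, `jax.hessian` can differentiate through it. `ridge_solver_jvp` is a hypothetical name, the unused `init_x` argument is dropped, and the data are float32 JAX arrays from a fixed seed.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def f(x, theta, X_train, y_train):  # Objective function, as in the example
    residual = jnp.dot(X_train, x) - y_train
    return (jnp.sum(residual ** 2) + theta * jnp.sum(x ** 2)) / 2

# Optimality condition: F(x, theta, X_train, y_train) == 0 at the solution.
F = jax.grad(f, argnums=0)


@jax.custom_jvp
def ridge_solver_jvp(theta, X_train, y_train):
    # Hypothetical variant of ridge_solver without the init_x argument.
    XX = jnp.dot(X_train.T, X_train)
    Xy = jnp.dot(X_train.T, y_train)
    I = jnp.eye(X_train.shape[1])
    return jnp.linalg.solve(XX + theta * I, Xy)


@ridge_solver_jvp.defjvp
def _ridge_solver_jvp_rule(primals, tangents):
    theta, X_train, y_train = primals
    x = ridge_solver_jvp(theta, X_train, y_train)
    # Directional derivative of F w.r.t. the parameters, with x held fixed.
    _, rhs = jax.jvp(lambda t, X, y: F(x, t, X, y), primals, tangents)
    # Jacobian of F w.r.t. x (for ridge regression: X^T X + theta * I).
    A = jax.jacobian(F, argnums=0)(x, theta, X_train, y_train)
    # Implicit function theorem: A @ x_dot = -rhs.
    x_dot = jnp.linalg.solve(A, -rhs)
    return x, x_dot


rng = onp.random.default_rng(0)
X_train = jnp.asarray(rng.standard_normal((100, 10)), dtype=jnp.float32)
y_train = jnp.asarray(rng.standard_normal(100), dtype=jnp.float32)

hess = jax.hessian(ridge_solver_jvp, argnums=0)(10.0, X_train, y_train)
print(hess.shape)  # (10,)
```

The tangent `x_dot` is linear in the input tangents, which lets JAX transpose the rule for reverse mode. The dense `jax.jacobian` call is only reasonable for small problems. A larger solver would instead pass a matrix-vector product to an iterative linear solver.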
@mblondel Is this going to be implemented?
I need this because I have a black-box solver from PETSc that I would like to use to solve the linear system.
Do you know if I can use `jax.lax.custom_linear_solve` for this?
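Below is a sketch of how `jax.lax.custom_linear_solve` could wrap an external solver, assuming that solver can be called from Python on NumPy arrays. Here `numpy.linalg.solve` stands in for PETSc, and `blackbox_solve`, `ridge_solver_cls` and `loss` are hypothetical names. The solver runs through `jax.pure_callback`, so JAX never differentiates it. `custom_linear_solve` obtains tangents and cotangents by calling the same solve on new right-hand sides and by differentiating `matvec`. The second derivative is computed as forward-over-reverse with `jax.jvp(jax.grad(...))` rather than with `jax.hessian`. This avoids `vmap`, because how `pure_callback` behaves under `vmap` depends on the JAX version.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def blackbox_solve(A, b):
    # Stand-in for an external solver (e.g. a PETSc solve); here plain NumPy.
    return onp.linalg.solve(onp.asarray(A), onp.asarray(b)).astype(b.dtype)


def ridge_solver_cls(theta, X_train, y_train):
    A = jnp.dot(X_train.T, X_train) + theta * jnp.eye(X_train.shape[1])
    b = jnp.dot(X_train.T, y_train)

    def matvec(v):
        # JAX differentiates this product, never the external solver.
        return jnp.dot(A, v)

    def solve(_matvec, rhs):
        # The external solver only ever receives concrete host arrays.
        out_type = jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)
        return jax.pure_callback(blackbox_solve, out_type, A, rhs)

    # A is symmetric, so the same solver also handles the transposed system.
    return jax.lax.custom_linear_solve(
        matvec, b, solve, transpose_solve=solve, symmetric=True)


def loss(theta, X_train, y_train):
    x = ridge_solver_cls(theta, X_train, y_train)
    return 0.5 * jnp.sum(x ** 2)


rng = onp.random.default_rng(0)
X_train = jnp.asarray(rng.standard_normal((100, 10)), dtype=jnp.float32)
y_train = jnp.asarray(rng.standard_normal(100), dtype=jnp.float32)

# d^2 loss / d theta^2 as forward-over-reverse (no vmap involved).
grad_fn = jax.grad(loss, argnums=0)
_, second = jax.jvp(lambda t: grad_fn(t, X_train, y_train), (10.0,), (1.0,))
print(second)
```

For a non-symmetric system, `transpose_solve` would have to solve with the transposed matrix instead of reusing `solve`.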
Is this in the plans? It would be great to have this, since it would make nesting optimizers easy.