
[Feature Request] Support for functorch transforms

Open marvinfriede opened this issue 1 year ago • 0 comments

Required prerequisites

  • [X] I have searched the Issue Tracker and Discussions and confirmed that this hasn't already been reported. (+1 or comment there if it has.)
  • [X] Consider asking first in a Discussion.

Motivation

I am interested in Jacobians and Hessians of solutions to root-finding problems, obtained through implicit differentiation. This comes up regularly in scientific computing. With JAX, this already works out of the box using function transforms (e.g., jax.jacrev). Is this something you plan to support in torchopt, too?

Solution

I already tried using the functorch transforms with torchopt's implicit differentiation, but several pieces of code currently prevent this:

  • missing setup_context for vmap rule in ImplicitMetaGradient (very easy to adapt)
  • .item() in _vdot_real_kernel
  • make_rmatvec in normal_cg
  • conditionals in _cg_solve
  • tree operations in _cg_solve
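
To illustrate two of the blockers above (the .item() call and the conditionals), here is a minimal sketch of a conjugate gradient loop written so that torch.func transforms can trace it. This is not torchopt's code: the names vdot_real and cg_fixed_iters, the fixed iteration count, and the tol parameter are assumptions made for this example. The idea is to keep every scalar as a 0-dim tensor and to replace Python-level branching on tensor values with torch.where, since both .item() and data-dependent `if` statements are unsupported inside vmap.

```python
import torch
from torch.func import jacrev, vmap


def vdot_real(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Real inner product that stays a 0-dim tensor (no .item() call)."""
    return torch.sum(x * y)


def cg_fixed_iters(matvec, b, maxiter=10, tol=1e-10):
    """Hypothetical CG for SPD systems: fixed trip count, no Python branching."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A @ x for the initial guess x = 0
    p = r
    rs = vdot_real(r, r)
    threshold = tol**2 * vdot_real(b, b)
    for _ in range(maxiter):
        # Boolean tensor instead of `if rs < threshold: break`.
        active = rs > threshold
        Ap = matvec(p)
        denom = vdot_real(p, Ap)
        # Substitute 1 in masked-out divisions so no 0/0 is ever evaluated.
        safe_denom = torch.where(active, denom, torch.ones_like(denom))
        alpha = torch.where(active, rs / safe_denom, torch.zeros_like(rs))
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = vdot_real(r, r)
        safe_rs = torch.where(active, rs, torch.ones_like(rs))
        beta = torch.where(active, rs_new / safe_rs, torch.zeros_like(rs))
        p = r + beta * p
        rs = rs_new
    return x


A = torch.tensor([[4.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
B = torch.tensor([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]], dtype=torch.float64)


def solve(b):
    return cg_fixed_iters(lambda v: A @ v, b)


X = vmap(solve)(B)       # one solve per row of B
J = jacrev(solve)(B[0])  # Jacobian of the solution with respect to b
```

Once a right-hand side has converged, `active` becomes false for it, the step size is zero, and the iterate stays frozen, so batched inputs that converge at different iterations do not interfere with each other. The cost is that all maxiter iterations are always executed.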

Alternatives

The jaxopt version
import jax
import jax.numpy as jnp
from jaxopt.implicit_diff import custom_root
from jaxopt import Bisection

jax.config.update("jax_platform_name", "cpu")


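def F(x, factor):
    # Optimality condition: the root x*(factor) satisfies F(x*, factor) = 0.
    return factor * x ** 3 - x - 2


def bisection_root_solver(init_x, factor):
    # init_x is unused; it only keeps the signature parallel to custom_root_solver.
    bisec = Bisection(optimality_fun=F, lower=1, upper=2)
    return bisec.run(factor=factor).params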


@custom_root(F)
def custom_root_solver(init_x, factor):
    """Root solver using gradient descent."""
    maxiter = 100
    lr = 1e-1

    x = init_x
    for _ in range(maxiter):
        # F plays the role of the gradient here: the update is x <- x - lr * F(x).
        grad = F(x, factor)
        x = x - lr * grad

    return x


x_init = jnp.array(3.0)
fac = jnp.array(2.0)

print(custom_root_solver(x_init, fac))
print(bisection_root_solver(x_init, fac))

print(jax.grad(custom_root_solver, argnums=1)(x_init, fac))
print(jax.grad(bisection_root_solver, argnums=1)(x_init, fac))

custom_jac_fcn = jax.jacrev(custom_root_solver, argnums=1)
print(jax.jacrev(custom_jac_fcn, argnums=1)(x_init, fac))
bisection_jac_fcn = jax.jacrev(bisection_root_solver, argnums=1)
print(jax.jacrev(bisection_jac_fcn, argnums=1)(x_init, fac))
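
As a cross-check for the derivatives printed above, the same quantities can be worked out by hand from the implicit function theorem. The sketch below uses only the standard library; bisect_root and root_derivatives are hypothetical helper names introduced for this example. Differentiating F(x*(factor), factor) = 0 once gives dx*/dfactor = -F_t / F_x, and differentiating again gives the second derivative. JAX defaults to float32, so agreement with the script above should only be expected to single precision.

```python
def F(x, factor):
    return factor * x**3 - x - 2


def bisect_root(factor, lower=1.0, upper=2.0, iters=200):
    """Plain bisection on F(., factor); assumes a sign change on [lower, upper]."""
    for _ in range(iters):
        mid = 0.5 * (lower + upper)
        if F(lower, factor) * F(mid, factor) <= 0:
            upper = mid
        else:
            lower = mid
    return 0.5 * (lower + upper)


def root_derivatives(factor):
    """Root x*(factor) and its first and second derivatives with respect to factor."""
    x = bisect_root(factor)
    F_x = 3 * factor * x**2 - 1  # dF/dx
    F_t = x**3                   # dF/dfactor
    dx = -F_t / F_x              # from F_x * dx + F_t = 0
    # Second-order terms, from differentiating F_x * dx + F_t = 0 once more.
    F_xx = 6 * factor * x
    F_xt = 3 * x**2
    F_tt = 0.0
    d2x = -(F_xx * dx**2 + 2 * F_xt * dx + F_tt) / F_x
    return x, dx, d2x


print(root_derivatives(2.0))
```

The three returned values correspond, in order, to the solver output, the jax.grad result, and the nested jax.jacrev result in the script above.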

Additional context

No response

marvinfriede, May 27 '24 15:05