lineax icon indicating copy to clipboard operation
lineax copied to clipboard

How to use PyTreeLinearOperator

Open ToshiyukiBandai opened this issue 8 months ago • 8 comments

Hi lineax community,

I came from Patrick's comment (https://github.com/google/jax/discussions/17203#discussioncomment-6814678). I believe PyTreeLinearOperator will do the job but am struggling to use it correctly. In the example below, I want to solve a Newton system $Jx = -F$, where the Jacobian matrix $F$ is a PyTree. How can I use PyTreeLinearOperator correctly?

import jax
from jax import jit, vmap, lax, jacfwd, jacrev, grad, vjp, jvp, random
import jax.numpy as jnp
from jax.config import config
from jax.tree_util import tree_structure, tree_flatten, tree_unflatten
import equinox as eqx
from jaxtyping import Float, Array, Bool
import lineax as lx

class Parameter(eqx.Module):
    alpha: Float[Array, ""]
    beta: Float[Array, ""]
    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta

class State(eqx.Module):
    x_0: Array
    x_1: Array
    def __init__(self, x_guess):
        self.x_0 = x_guess[0]
        self.x_1 = x_guess[1]

class Model(eqx.Module):
    parameters: eqx.Module
    def __init__(self, parameters):
        self.parameters = parameters
    def residual(self, state):
        F_0 = state.x_0**2 + state.x_1**2 - self.parameters.alpha
        F_1 = self.parameters.beta*state.x_0**3 - state.x_1
        return jnp.array([F_0, F_1])

alpha = 4.0
beta = 1.0
x_test = jnp.array([2.0, 3.0])

parameter = Parameter(alpha, beta)
state = State(x_test)
model = Model(parameter)

Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
# lx.linear_solve(J, F) # failed

ToshiyukiBandai avatar Oct 25 '23 00:10 ToshiyukiBandai