lineax
lineax copied to clipboard
How to use PyTreeLinearOperator
Hi lineax community,
I came from Patrick's comment (https://github.com/google/jax/discussions/17203#discussioncomment-6814678). I believe PyTreeLinearOperator
will do the job but am struggling to use it correctly. In the example below, I want to solve a Newton system $Jx = -F$, where the Jacobian matrix $F$ is a PyTree. How can I use PyTreeLinearOperator
correctly?
import jax
from jax import jit, vmap, lax, jacfwd, jacrev, grad, vjp, jvp, random
import jax.numpy as jnp
from jax.config import config
from jax.tree_util import tree_structure, tree_flatten, tree_unflatten
import equinox as eqx
from jaxtyping import Float, Array, Bool
import lineax as lx
class Parameter(eqx.Module):
alpha: Float[Array, ""]
beta: Float[Array, ""]
def __init__(self, alpha, beta):
self.alpha = alpha
self.beta = beta
class State(eqx.Module):
x_0: Array
x_1: Array
def __init__(self, x_guess):
self.x_0 = x_guess[0]
self.x_1 = x_guess[1]
class Model(eqx.Module):
parameters: eqx.Module
def __init__(self, parameters):
self.parameters = parameters
def residual(self, state):
F_0 = state.x_0**2 + state.x_1**2 - self.parameters.alpha
F_1 = self.parameters.beta*state.x_0**3 - state.x_1
return jnp.array([F_0, F_1])
alpha = 4.0
beta = 1.0
x_test = jnp.array([2.0, 3.0])
parameter = Parameter(alpha, beta)
state = State(x_test)
model = Model(parameter)
Jacobian_JAX_class = jacfwd(model.residual, argnums=0, has_aux=False)
F = model.residual(state)
J = Jacobian_JAX_class(state)
J = lx.PyTreeLinearOperator(J, jax.eval_shape(lambda: state))
# lx.linear_solve(J, F) # failed