Problem with usage of Linear Operators
Hello,
I've encountered a problem when trying to create a Gauss-Newton-like matrix in terms of lineax.FunctionLinearOperators. My test_func takes and returns a pytree. The final line, where I try to compute j2, fails with an assertion error.
Am I doing something wrong, or is what I'm trying to do just not possible?
import lineax as lx
import equinox as eqx
import jax
import jax.numpy as jnp

class my_value(eqx.Module):
    x: None
    y: None

def test_func(x):
    return jax.tree.map(lambda x: jnp.abs(x)**2, x)

x = my_value(jnp.ones((3,1)), jnp.ones((3,1)))
jac = lx.FunctionLinearOperator(jax.jacobian(test_func), jax.eval_shape(lambda: x))
j2 = jac.transpose() @ jac
site-packages/jax/_src/lax/lax.py:7400, in _select_transpose_rule(t, which, *cases)
   7399 def _select_transpose_rule(t, which, *cases):
-> 7400   assert not ad.is_undefined_primal(which)
   7401   if type(t) is ad_util.Zero:
   7402     return [None] + [ad_util.Zero(c.aval) if ad.is_undefined_primal(c) else None
   7403                      for c in cases]
AssertionError:
Hi,
I think this is a reasonable expectation, given that we transpose with
https://github.com/patrick-kidger/lineax/blob/6bf1a0e17c6352b96bdf52e72772fac63eb641df/lineax/_operator.py#L700
and
jax.linear_transpose(jax.jacobian(test_func), x)
works. I found a couple of other ways to get what you want:
_, lin_fn = jax.linearize(test_func, x)
jac = lx.FunctionLinearOperator(lin_fn, jax.eval_shape(lambda: x))
jac.transpose() @ jac
or
jac = lx.FunctionLinearOperator(jax.jacobian(test_func), jax.eval_shape(lambda: x))
lx.materialise(jac).transpose() @ jac
or
def test_func(x, args):  # This now requires a second argument
    return jax.tree.map(lambda x: jnp.abs(x)**2, x)
x = my_value(jnp.ones((3,1)), jnp.ones((3,1)))
jac = lx.JacobianLinearOperator(test_func, x, None)
jac.transpose() @ jac
For the latter, you can also wrap it in lx.linearise if you're planning on re-using the operator:
jac = lx.linearise(lx.JacobianLinearOperator(test_func, x, None))
jac.transpose() @ jac
I haven't looked at the implementations beyond the transposing method, and I would need to do that to understand why the transpose of a FunctionLinearOperator behaves the way it does (and if that could/should be changed).
I hope you have some options until I find a quiet minute!
Here's an MWE without Lineax, using just JAX:
import jax
import jax.numpy as jnp
def test_func(x):
    return jnp.abs(x)**2
x = jnp.ones((3, 3))
jac_func = jax.jacrev(test_func)
jax.linear_transpose(jac_func, jax.ShapeDtypeStruct((3,), jnp.float32))(x)
Despite the fact that jacrev(test_func) is linear, it seems that JAX doesn't know how to transpose it.
Interestingly, this can be fixed by explicitly linearising it:
import jax
import jax.numpy as jnp
def test_func(x):
    return jnp.abs(x)**2
x = jnp.ones((3, 3))
jac_func = jax.jacrev(test_func)
_, jac_func = jax.linearize(jac_func, jnp.ones(3))
jax.linear_transpose(jac_func, jax.ShapeDtypeStruct((3,), jnp.float32))(x)
Here the choice of linearisation point (jnp.ones(3)) is arbitrary, since the function being linearised is already linear.
At the JAX level, this reflects what I think is a known limitation of JAX -- its treatment of 'is this function linear?' is only really tested on the output of jax.linearize, and it has false negatives for user-constructed linear functions. Explicitly linearising them tends to fix this.
Returning to Lineax, I think @johannahaffner's suggestion of using lx.linearise on the linear operator is the correct corresponding solution.
Ah interesting, jax.linear_transpose(jac_func, jax.ShapeDtypeStruct((3,), jnp.float32)) gives reasonable-looking output, but it seems there is a check that is only run once this thing is called.
Thanks for spotting this, @patrick-kidger!