
Problem with usage of Linear Operators

Open matillda123 opened this issue 4 months ago • 3 comments

Hello,

I've encountered a problem when trying to create a Gauss-Newton-like matrix in terms of lineax.FunctionLinearOperators. My test_func takes and returns a pytree. The final line, where I try to compute j2, fails with an AssertionError. Am I doing something wrong, or is what I'm trying to do just not possible?

import lineax as lx
import equinox as eqx
import jax
import jax.numpy as jnp


class my_value(eqx.Module):
    x: jax.Array
    y: jax.Array


def test_func(x):
    # Elementwise |x|^2 applied to every leaf of the pytree.
    return jax.tree.map(lambda leaf: jnp.abs(leaf) ** 2, x)


x = my_value(jnp.ones((3, 1)), jnp.ones((3, 1)))
jac = lx.FunctionLinearOperator(jax.jacobian(test_func), jax.eval_shape(lambda: x))
j2 = jac.transpose() @ jac  # raises the AssertionError below

site-packages/jax/_src/lax/lax.py:7400, in _select_transpose_rule(t, which, *cases)
   7399 def _select_transpose_rule(t, which, *cases):
-> 7400   assert not ad.is_undefined_primal(which)
   7401   if type(t) is ad_util.Zero:
   7402     return [None] + [ad_util.Zero(c.aval) if ad.is_undefined_primal(c) else None
   7403                      for c in cases]

AssertionError:

matillda123 · Aug 28 '25 11:08

Hi,

I think this is a reasonable expectation, given that we transpose with

https://github.com/patrick-kidger/lineax/blob/6bf1a0e17c6352b96bdf52e72772fac63eb641df/lineax/_operator.py#L700

and

jax.linear_transpose(jax.jacobian(test_func), x)

works. I found a couple of other ways to get what you want:

_, lin_fn = jax.linearize(test_func, x) 
jac = lx.FunctionLinearOperator(lin_fn, jax.eval_shape(lambda: x))
jac.transpose() @ jac

or

jac = lx.FunctionLinearOperator(jax.jacobian(test_func), jax.eval_shape(lambda: x))
lx.materialise(jac).transpose() @ jac

or

def test_func(x, args):  # This now requires a second argument
    return jax.tree.map(lambda x: jnp.abs(x)**2, x)

x = my_value(jnp.ones((3,1)), jnp.ones((3,1)))

jac = lx.JacobianLinearOperator(test_func, x, None)
jac.transpose() @ jac

For the latter, you can also wrap it in lx.linearise if you're planning on reusing the operator:

jac = lx.linearise(lx.JacobianLinearOperator(test_func, x, None))
jac.transpose() @ jac

I haven't looked at the implementations beyond the transposing method, and I would need to do that to understand why the transpose of a FunctionLinearOperator behaves the way it does (and whether that could or should be changed).

I hope you have some options until I find a quiet minute!

johannahaffner · Aug 28 '25 21:08

Here's an MWE without Lineax, using just JAX:

import jax
import jax.numpy as jnp

def test_func(x):
    return jnp.abs(x)**2

x = jnp.ones((3, 3))
jac_func = jax.jacrev(test_func)
jax.linear_transpose(jac_func, jax.ShapeDtypeStruct((3,), jnp.float32))(x)

Despite the fact that jacrev(test_func) is linear, it seems that JAX doesn't know how to transpose it.

Interestingly, this can be fixed by explicitly linearising it:

import jax
import jax.numpy as jnp

def test_func(x):
    return jnp.abs(x)**2

x = jnp.ones((3, 3))
jac_func = jax.jacrev(test_func)
_, jac_func = jax.linearize(jac_func, jnp.ones(3))
jax.linear_transpose(jac_func, jax.ShapeDtypeStruct((3,), jnp.float32))(x)

Here the choice of linearisation point (jnp.ones(3)) is arbitrary, since the function being linearised is already linear.

At the JAX level, this reflects what I think is a known limitation of JAX: its treatment of "is this function linear?" is only really tested on the output of jax.linearize, and it has false negatives for user-constructed linear functions. Explicitly linearising them tends to fix this.


Returning to Lineax, I think @johannahaffner's suggestion of using lx.linearise on the linear operator is the correct corresponding solution.

patrick-kidger · Aug 31 '25 11:08

Ah, interesting: jax.linear_transpose(jac_func, jax.ShapeDtypeStruct((3,), jnp.float32)) returns a reasonable-looking transposed function, but it seems there is a check that only runs once that function is called.

Thanks for spotting this, @patrick-kidger!

johannahaffner · Aug 31 '25 11:08