
Creating JacobianLinearOperator fails with confusing error

Open · johannahaffner opened this issue 5 months ago · 2 comments

I'm getting a confusing error when trying to create a JacobianLinearOperator.

  • `lx.JacobianLinearOperator(f, x0)` raises `TypeError: f() takes 1 positional argument but 2 were given`.
  • `lx.JacobianLinearOperator(f)` raises `TypeError: JacobianLinearOperator.__init__() missing 1 required positional argument: 'x'`.

Based on the documentation, I assumed that the first one would be correct.

Here is an MWE:

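```python
import jax
import jax.numpy as jnp
import lineax as lx

def f(x):
    return jnp.sum(x)

x0 = jnp.array([0.1, 0.2])
x1 = jnp.array([0.2, 0.3])

# Forward-mode JVP of f at x0 in direction x1: this call works.
primals, tangents = jax.jvp(f, (x0,), (x1,))

# Raises TypeError: f() takes 1 positional argument but 2 were given
lx.JacobianLinearOperator(f, x0)
```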

I'll poke around a bit to see if I can figure out why this happens.

johannahaffner, Sep 23 '24 15:09