lineax
Creating JacobianLinearOperator fails with confusing error
I'm getting a confusing error when trying to create a JacobianLinearOperator:
- lx.JacobianLinearOperator(f, x0) causes TypeError: f() takes 1 positional argument but 2 were given
- lx.JacobianLinearOperator(f) causes TypeError: JacobianLinearOperator.__init__() missing 1 required positional argument: 'x'
Based on the documentation, I assumed that the first one would be correct.
Here is an MWE:
import jax
import jax.numpy as jnp
import lineax as lx
def f(x):
    return jnp.sum(x)
x0 = jnp.array([0.1, 0.2])
x1 = jnp.array([0.2, 0.3])
primals, tangents = jax.jvp(f, (x0,), (x1,))
lx.JacobianLinearOperator(f, x0)
I'll poke around a bit to see if I can figure out why this happens.