Patrick Kidger
Yup, you can get the inverse by solving against the identity matrix:

```python
A = ...
A_inverse = jax.vmap(lambda b: lx.linear_solve(A, b))(jnp.eye(...))
```

Incidentally this trick is also how pretty...
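A slightly expanded, runnable version of that sketch (assuming the dense-matrix case, so `A` is wrapped in `lx.MatrixLinearOperator`; note that `lx.linear_solve` returns a solution object, so we take `.value`, and that `jax.vmap` stacks the per-column solutions as rows, hence the final transpose):

```python
import jax
import jax.numpy as jnp
import lineax as lx

A = jnp.array([[3.0, 1.0], [0.0, 2.0]])  # example (non-symmetric) matrix
operator = lx.MatrixLinearOperator(A)

# Solve A x = e_i for every standard basis vector e_i. Each solution is a
# column of A^{-1}; vmap stacks the solutions as rows, so transpose at the end.
solve_column = lambda b: lx.linear_solve(operator, b).value
A_inverse = jax.vmap(solve_column)(jnp.eye(2)).T

print(jnp.allclose(A_inverse @ A, jnp.eye(2)))  # True
```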
Side note @johannahaffner I think we could fix this to work with `jax.vmap` by wrapping the jaxpr in an `eqxi.Static`. It's probably worth being compatible with `jax.vmap` by default; I...
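(For context, `eqxi.Static` is an internal Equinox helper; a rough, hedged approximation of the idea using only the public API -- a wrapper whose contents are stored as static pytree metadata, so `jax.vmap` passes them through untouched -- might look like this:)

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class StaticWrapper(eqx.Module):
    # Hypothetical stand-in for `eqxi.Static`: the wrapped object is stored as
    # a static field (pytree aux data), so transformations like `jax.vmap`
    # treat it as metadata rather than as a value to be batched.
    value: object = eqx.field(static=True)

jaxpr = jax.make_jaxpr(lambda x: x + 1)(jnp.array(1.0))
wrapped = StaticWrapper(jaxpr)

def f(x, w):
    # The wrapper rides along unchanged while `x` is batched.
    return x * 2, w

out, w_out = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(
    jnp.arange(3.0), wrapped
)
```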
Chipping in my 2 cents, I think these sound in-scope for Optimistix. :)
Hmm, bother. I think fixing the underlying issue comes under "I'd be happy to take a PR on that". We could then add a test at the same time. (Ideally...
I think that's a good spot! Tagging @lockwo and #948 here -- we're thinking of reworking BatchNorm. @lockwo WDYT?
Yup! So, at least from Optimistix, you can pass `optx.root_find(..., tags=frozenset({lineax.tridiagonal_tag}))` to specify the structure of the Jacobian of the target function. Note that this is a promise that...
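A hedged sketch of what that call might look like, assuming a root-find problem whose Jacobian genuinely is tridiagonal (here a toy nonlinear system with nearest-neighbour coupling) and using `optx.Newton` as the solver:

```python
import jax.numpy as jnp
import lineax as lx
import optimistix as optx

def fn(y, args):
    # Each residual only couples y[i-1], y[i], y[i+1], so the Jacobian is
    # tridiagonal.
    left = jnp.concatenate([jnp.zeros(1), y[:-1]])
    right = jnp.concatenate([y[1:], jnp.zeros(1)])
    return jnp.tanh(y) + 0.1 * (left + right) - args

y0 = jnp.zeros(8)
target = jnp.linspace(-1.0, 1.0, 8)
solver = optx.Newton(rtol=1e-8, atol=1e-8)

# The tag is a promise that the Jacobian has tridiagonal structure, which
# lets lineax dispatch to a specialised linear solve.
sol = optx.root_find(
    fn, solver, y0, args=target, tags=frozenset({lx.tridiagonal_tag})
)
print(sol.value)
```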
> I can confirm that simply adding the tag in decreases runtime by about 30% on CPU on my lineax branch with `unroll=1` in the solver

Awesome!

> (apologies for...
Aha, @johannahaffner recalls details that I don't! Indeed this is probably achievable by adjusting the solver definitions then. (I would need to think about how.)
Interesting! So the `while_loop` is provided by the choice of adjoint method, and e.g. the default `ImplicitAdjoint` is actually just providing a `lax.while_loop`... which is exactly the same thing that...
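(To make the adjoint choice concrete, a small hedged sketch of the call site: `ImplicitAdjoint` is the default mentioned above, and, if I recall correctly, Optimistix also offers `RecursiveCheckpointAdjoint` for differentiating through the iterations themselves -- the toy `fn`/`solver`/`y0` below are just placeholders.)

```python
import jax.numpy as jnp
import optimistix as optx

fn = lambda y, args: jnp.tanh(y) - args  # toy root-find problem
solver = optx.Newton(rtol=1e-8, atol=1e-8)
y0 = jnp.zeros(4)
args = jnp.linspace(-0.5, 0.5, 4)

# Default: gradients via the implicit function theorem; the forward solve
# itself runs inside a `lax.while_loop`.
sol = optx.root_find(fn, solver, y0, args=args, adjoint=optx.ImplicitAdjoint())

# Alternative: backpropagate through the iterations, via a checkpointed loop.
sol_ckpt = optx.root_find(
    fn, solver, y0, args=args, adjoint=optx.RecursiveCheckpointAdjoint()
)
```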
Sorry, I wasn't sufficiently clear: `lax.fori_loop` lowers to `lax.scan`... which then itself lowers to `lax.while_loop`, just with a very simple `cond_fn` that checks the number of iterations. (Hence my comment...
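To make that relationship concrete, here's a small sketch (illustrative only, not how JAX literally performs the lowering): the same loop written with `lax.fori_loop` and with a `lax.while_loop` whose `cond_fn` does nothing but check an iteration counter.

```python
import jax.numpy as jnp
from jax import lax

def body(i, x):
    return x + i

# Version 1: fori_loop over a fixed number of iterations.
out1 = lax.fori_loop(0, 10, body, jnp.array(0.0))

# Version 2: the equivalent while_loop, with a cond_fn that just checks the
# number of iterations so far.
def cond_fn(carry):
    i, _ = carry
    return i < 10

def body_fn(carry):
    i, x = carry
    return i + 1, body(i, x)

_, out2 = lax.while_loop(cond_fn, body_fn, (0, jnp.array(0.0)))

assert out1 == out2  # both compute 0 + 1 + ... + 9 = 45
```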