General Questions
Hey,
can lineax be used to get the actual inverse of a matrix? Or is one limited to solving a particular equation? I couldn't find that in the documentation. And is the general recommendation that lineax is to be preferred over the "equivalent" functions from jax (because of speed, bugfixes and so on), or are there cases where jax should be used?
Thanks for the help. :)
Yup, you can get the inverse by solving against the identity matrix:
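import jax, jax.numpy as jnp, lineax as lx
A = ...  # your square, non-singular matrix
operator = lx.MatrixLinearOperator(A)  # lineax solves against linear operators, not raw arrays
A_inverse = jax.vmap(lambda b: lx.linear_solve(operator, b).value, out_axes=1)(jnp.eye(...))  # .value holds the solution; out_axes=1 stacks each solve as a column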
Incidentally, this trick (solving against the identity) is also how pretty much every inverse-finding routine in any numerical library works under the hood.
As for comparing against JAX builtins, see this FAQ entry. The speed differences relative to JAX usually aren't that meaningful (although you can find cases where they're larger); the main advantages are usually the other points highlighted there.