Linear solvers from external libraries
Hi Patrick,
I wanted to experiment with some linear solvers from PETSc, and I managed to do it by subclassing lx.AbstractLinearSolver and using jax.pure_callback. This ensures that all of JAX's transformations work super well.
However, this will place the computation on the host even if the external library can support GPU. Do you know if it would be possible to bypass this restriction somehow?
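Here is a stripped-down sketch of just the pure_callback half of that approach, assuming only JAX and NumPy. np.linalg.solve stands in for the PETSc call, and the helper names `_host_solve` / `host_linear_solve` are hypothetical. It leaves out the lx.AbstractLinearSolver subclass (and therefore the autodiff wiring), but it shows the limitation described above: even under `jit`, the operands are handed to the callback as NumPy arrays on the host, and the result is moved back afterwards.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_solve(A, b):
    # Runs on the host: JAX passes the operands in as NumPy arrays.
    # np.linalg.solve is a stand-in for an external library such as PETSc.
    b = np.asarray(b)
    return np.linalg.solve(np.asarray(A), b).astype(b.dtype)


def host_linear_solve(A, b):
    # pure_callback needs the output shape/dtype up front so JAX can trace it.
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    # vmap_method is omitted; recent JAX versions want it set if you vmap this.
    return jax.pure_callback(_host_solve, out_spec, A, b)


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = jax.jit(host_linear_solve)(A, b)  # the solve itself still executes on the host
```

Because the callback is opaque to XLA, it always executes on the host, even when `A` and `b` live on a GPU. Keeping the solve on the device means registering the library as an XLA custom call, which is where jax.ffi comes in.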
@johnviljoen this might be relevant just a little further down the line
@johannahaffner Is this possible?
Is jax.ffi needed?
Hi! Johanna and I are working on some fancy large-scale nonlinear constrained optimizers in optimistix, and as part of that I have integrated NVIDIA cuDSS into JAX for large, vectorizable, sparse, refactorizable direct linear solves on GPU. (This does require FFI; there is no way around it.)
I will be releasing this in a month or two, but it already works, and depending on your needs I might be able to share it with you, since that would also help me battle-test it! Send me an email at [email protected] if you're interested! (The hope is to also integrate it into lineax, but I haven't gotten there yet.)
I have no further information beyond what @johannahaffner and @johnviljoen have already shared, hopefully they can help! :)