Linear solvers from external libraries

Open SNMS95 opened this issue 6 months ago • 4 comments

Hi Patrick,

I wanted to experiment with some linear solvers from PETSc, and I managed to do it by subclassing lx.AbstractLinearSolver and using jax.pure_callback. This ensured that all of JAX's transformations work super well.

However, this places the computation on the host even if the external library supports GPUs. Do you know if it is possible to bypass this restriction somehow?
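
For readers following along, here is a minimal sketch of the callback part of the approach described above. It is not the code from this issue: `np.linalg.solve` stands in for the PETSc call, and the names `_host_solve` and `external_solve` are made up for illustration. The lineax subclass itself is omitted.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_solve(a, b):
    # Hypothetical stand-in for an external library call (e.g. a PETSc solve).
    # This function runs on the host, outside of JAX's tracing.
    a = np.asarray(a)
    b = np.asarray(b)
    return np.linalg.solve(a, b).astype(b.dtype)


@jax.jit
def external_solve(a, b):
    # JAX cannot trace into the callback, so the output shape and dtype
    # must be declared up front.
    out_spec = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(_host_solve, out_spec, a, b)


a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = external_solve(a, b)
```

Note that jax.pure_callback on its own does not define derivatives, so in a setup like this the differentiation rules have to come from whatever wraps the callback (here, presumably the lineax solve around the solver subclass). The arrays are also copied to the host before `_host_solve` runs, which is the restriction the question is about.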

SNMS95 avatar Jun 17 '25 09:06 SNMS95

@johnviljoen this might be relevant just a little further down the line

johannahaffner avatar Jun 17 '25 09:06 johannahaffner

@johannahaffner Is this possible? Is jax.ffi needed?

SNMS95 avatar Jun 17 '25 12:06 SNMS95

Hi! So Johanna and I are working on some fancy large-scale nonlinear constrained optimizers in optimistix, and as part of that I have integrated NVIDIA cuDSS into JAX for the purpose of large, vectorizable, sparse, refactorizable direct linear solves on GPU. (This does require FFI; there is no way around it.)

I will be releasing this in a month or two, but it currently works, and depending on your needs I might be able to share it with you, as this would also help me battle-test it! Send me an email at [email protected] if you're interested! (The hope is to also integrate it into lineax, but I haven't gotten there yet.)

johnviljoen avatar Jun 17 '25 12:06 johnviljoen

I have no further information beyond what @johannahaffner and @johnviljoen have already shared; hopefully they can help! :)

patrick-kidger avatar Jun 17 '25 13:06 patrick-kidger