Suryanarayanan Manoj Sanu

23 comments by Suryanarayanan Manoj Sanu

Thanks, Patrick. I will attempt to do that in this thread (and create an MWE), and then close it so that it can be useful for others as well.

Hey @vboussange, I did. It looks something like this: since I had to deal with sparsity on top as well, I first made a `SparseOperator` to wrap `jax.experimental.sparse.BCOO`. If...

Hi @vboussange, if you use `filter_pure_callback` from Equinox, you should be able to store Python objects in the initial state as well. Would be nice if you crack this,...

I was wondering how it would be possible to get the gradients w.r.t., say, the preconditioner or the tolerance. If we differentiate according to the implicit function theorem, these should be...
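Under the implicit function theorem, only quantities that enter the defining equation contribute to the gradient. The short derivation below assumes that the object being differentiated is the exact solution $x^\star(\theta)$ rather than the output of a truncated iteration. $P$ (preconditioner) and $\tau$ (tolerance) are symbols introduced here for illustration.

```latex
A(\theta)\,x^\star(\theta) = b(\theta)
\;\Longrightarrow\;
A\,\frac{\partial x^\star}{\partial \theta}
  + \frac{\partial A}{\partial \theta}\,x^\star
  = \frac{\partial b}{\partial \theta}
\;\Longrightarrow\;
\frac{\partial x^\star}{\partial \theta}
  = A^{-1}\!\left(\frac{\partial b}{\partial \theta}
  - \frac{\partial A}{\partial \theta}\,x^\star\right)

% P and \tau appear in neither A nor b, hence
\frac{\partial x^\star}{\partial P} = 0,
\qquad
\frac{\partial x^\star}{\partial \tau} = 0.
```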

> You can do this by calling (and differentiating through) `solver.compute` directly.
>
> It's true that we don't offer this by default in Lineax. I think this is probably...
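Lineax's `solver.compute` is not reproduced here. As a self-contained stand-in, the sketch below swaps in a diagonally preconditioned Richardson iteration (not Lineax's CG). It unrolls a fixed number of solver steps and differentiates through them with respect to the preconditioner. It shows how this differs from the implicit-function-theorem view: with only a few steps the gradient is nonzero, while once the iteration has converged it essentially vanishes. The system and step counts are illustrative.

```python
import jax
import jax.numpy as jnp

# Small symmetric positive-definite system A x = b (illustrative values).
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])


def unrolled_solve(precond_diag, num_steps):
    # Preconditioned Richardson iteration x <- x + P (b - A x), P = diag(precond_diag).
    # A static step count lets reverse-mode autodiff go through every iteration.
    def body(_, x):
        return x + precond_diag * (b - A @ x)

    return jax.lax.fori_loop(0, num_steps, body, jnp.zeros_like(b))


def loss(precond_diag, num_steps):
    return jnp.sum(unrolled_solve(precond_diag, num_steps))


p0 = 1.0 / jnp.diag(A)  # Jacobi preconditioner
grad_few = jax.grad(loss)(p0, 5)     # truncated solve: gradient depends on P
grad_many = jax.grad(loss)(p0, 100)  # converged solve: gradient close to zero
```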

> Hey [@SNMS95](https://github.com/SNMS95), we ([@bhorowitz](https://github.com/bhorowitz), [@gdalle](https://github.com/gdalle), myself...) are also quite interested in porting AMGX to JAX. Here are some leads:
>
> 1. Using the [`ffi`](https://docs.jax.dev/en/latest/ffi.html) feature of JAX to...

@johannahaffner, is this possible? Is `jax.ffi` needed?
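The quoted lead is about `jax.ffi`. For a quick prototype, a lighter-weight (and slower) option is `jax.pure_callback`, which calls ordinary host-side Python from inside a jitted function. In the sketch below, SciPy's `spsolve` is an assumed stand-in for an external solver such as AMGX. It is not an AMGX binding. Data round-trips through host memory, and an FFI custom call can avoid that.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

# Host-side sparse matrix; stands in for whatever the external solver holds.
_A_host = sp.diags([-1.0, 2.0, -1.0], [-1, 0, 1], shape=(5, 5), format="csr")


def _host_solve(b):
    # Runs on the host with a NumPy input; the result must match the declared dtype.
    x = spla.spsolve(_A_host, np.asarray(b, dtype=np.float64))
    return x.astype(np.float32)


@jax.jit
def solve(b):
    return jax.pure_callback(
        _host_solve, jax.ShapeDtypeStruct(b.shape, jnp.float32), b
    )


x = solve(jnp.ones(5, dtype=jnp.float32))
```

`pure_callback` is not differentiable by itself. Gradients would need a `jax.custom_vjp`, and for a linear solve the backward pass is another solve with the transposed matrix.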

Hey, I did. However, in the original paper, the algorithm is bilevel and the slow parameters are the "main" parameters. In the example, the slow parameters are only used during...
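The snippet does not name the method. The following sketch assumes a Lookahead-style slow/fast-weights scheme and writes it as a bilevel loop in plain JAX, with the slow weights as the "main" parameters. The inner level runs a few SGD steps on fast weights initialised from the slow weights. The outer level moves the slow weights part of the way toward the result. The objective and hyperparameters are illustrative.

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Illustrative quadratic objective with minimum at w = [1, -2].
    return jnp.sum((w - jnp.array([1.0, -2.0])) ** 2)


def outer_step(slow, inner_steps=5, lr=0.1, alpha=0.5):
    # Inner level (assumed Lookahead-style): fast weights start at the slow
    # weights and take a few plain SGD steps.
    def inner(_, fast):
        return fast - lr * jax.grad(loss)(fast)

    fast = jax.lax.fori_loop(0, inner_steps, inner, slow)
    # Outer level: slow ("main") weights move a fraction alpha toward the fast weights.
    return slow + alpha * (fast - slow)


slow = jnp.zeros(2)
for _ in range(50):
    slow = outer_step(slow)
```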

Hi @patrick-kidger, this is an even more minimal example, but it errors out:

```python
import jax
import jax.flatten_util
import jax.numpy as jnp
import equinox as eqx
from collections import namedtuple...
```

Hi Patrick, I upgraded everything and the issue is still there. I managed to reduce the example further! Will this be okay?

```python
import equinox as eqx
import optimistix as...
```