Suryanarayanan Manoj Sanu
Thanks Patrick. I will attempt to do that in this thread (and create an MWE), and then close this so that it can be useful to others as well.
Hey @vboussange, I did; it looks something like this: since I had to deal with sparsity on top as well, I first made a `SparseOperator` to wrap `jax.experimental.sparse.BCOO`. If...
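
The comment above is truncated, but here is a minimal sketch of what such a wrapper can look like, assuming a custom subclass of `lineax.AbstractLinearOperator` wrapping a `jax.experimental.sparse.BCOO` matrix; the class name and method bodies are illustrative, not the code from the thread:

```
import jax
import jax.numpy as jnp
import jax.experimental.sparse as jsp
import lineax as lx


class SparseOperator(lx.AbstractLinearOperator):
    """Hypothetical wrapper exposing a BCOO sparse matrix as a Lineax operator."""

    matrix: jsp.BCOO

    def mv(self, vector):
        # Sparse matrix-vector product, without ever densifying.
        return self.matrix @ vector

    def as_matrix(self):
        # Densify only if a solver explicitly asks for the full matrix.
        return self.matrix.todense()

    def transpose(self):
        return SparseOperator(self.matrix.transpose())

    def in_structure(self):
        return jax.ShapeDtypeStruct((self.matrix.shape[1],), self.matrix.dtype)

    def out_structure(self):
        return jax.ShapeDtypeStruct((self.matrix.shape[0],), self.matrix.dtype)


op = SparseOperator(jsp.BCOO.fromdense(jnp.array([[2.0, 0.0], [0.0, 3.0]])))
print(op.mv(jnp.array([1.0, 1.0])))  # sparse matvec
```

Depending on the Lineax version, actually passing such a custom operator into `lx.linear_solve` may additionally require registering it with Lineax's tag/transform functions, so treat this as a starting point rather than a complete implementation.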
Hi @vboussange, if you use `filter_pure_callback` from Equinox, you should be able to store Python objects in the initial state as well. Would be nice if you crack this, ...
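
For anyone landing here, a minimal sketch of the general `eqx.filter_pure_callback` pattern being referred to (just the mechanism of escaping the trace, not the actual solver-state code from this thread); the callback and the doubling operation are made up for illustration, and `result_shape_dtypes` as a keyword argument is assumed from recent Equinox versions:

```
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np


def _python_side(x):
    # Ordinary Python, executed outside the trace: this is where arbitrary
    # Python objects (e.g. an external solver handle) could be touched.
    return np.asarray(x) * 2


@jax.jit
def f(x):
    # Declare the output structure so JAX knows the shape/dtype up front.
    out_like = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return eqx.filter_pure_callback(_python_side, x, result_shape_dtypes=out_like)


print(f(jnp.arange(3.0)))  # [0. 2. 4.]
```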
I was wondering how it would be possible to get the gradients w.r.t., say, the preconditioner or the tolerance. If we differentiate according to the implicit function theorem, these should be...
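
To spell out the reasoning behind the question (my own summary, not an answer from the thread): the implicit function theorem only sees the linear system itself, so anything that merely controls how the solve converges drops out of the derivative.

```
% Implicit differentiation of a linear solve. At the solution of A x = b the
% residual vanishes, and differentiating that condition gives:
\[
  A\,\mathrm{d}x + (\mathrm{d}A)\,x - \mathrm{d}b = 0
  \quad\Longrightarrow\quad
  \mathrm{d}x = A^{-1}\bigl(\mathrm{d}b - (\mathrm{d}A)\,x\bigr).
\]
% Neither the preconditioner nor the tolerance appears in this expression, so
% their implicit-function-theorem gradients are zero.
```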
> You can do this by calling (and differentiating through) `solver.compute` directly.
>
> It's true that we don't offer this by default in Lineax. I think this is probably...
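
Here is roughly how I read that suggestion, as a sketch: call the solver's `init`/`compute` pair directly and differentiate through `compute`, so the gradient flows through the iterations themselves. The exact signatures (`init(operator, options)`, `compute(state, vector, options)`), the `"preconditioner"` options key, and the CG settings below are my assumptions about the Lineax API and may need adjusting:

```
import jax
import jax.numpy as jnp
import lineax as lx

A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, 2.0])
operator = lx.MatrixLinearOperator(A, lx.positive_semidefinite_tag)
solver = lx.CG(rtol=1e-6, atol=1e-6, max_steps=100)


def loss(precond_diag):
    # Pass the preconditioner via the solver options (assumed key) and call
    # init/compute directly instead of going through lx.linear_solve.
    options = {"preconditioner": lx.DiagonalLinearOperator(precond_diag)}
    state = solver.init(operator, options)
    solution, _, _ = solver.compute(state, b, options)
    return jnp.sum(solution**2)


# Differentiating through `compute` unrolls the solve, so the preconditioner
# now receives a (generally nonzero) gradient, unlike under the
# implicit-function-theorem rule used by lx.linear_solve.
print(jax.grad(loss)(jnp.ones(2)))
```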
> Hey [@SNMS95](https://github.com/SNMS95), we ([@bhorowitz](https://github.com/bhorowitz), [@gdalle](https://github.com/gdalle), myself...) are also quite interested in porting AMGX to JAX. Here are some leads:
>
> 1. Using the [`ffi`](https://docs.jax.dev/en/latest/ffi.html) feature of JAX to...
@johannahaffner Is this possible? Is `jax.ffi` needed?
Hey, I did. However, in the original paper, the algorithm is bilevel and the slow parameters are the "main" parameters. In the example, the slow parameters are only used during...
Hi @patrick-kidger, this is an even more minimal example, but it errors out.

```
import jax
import jax.flatten_util
import jax.numpy as jnp
import equinox as eqx
from collections import namedtuple
...
```
Hi Patrick, I upgraded everything and the issue is still there. I managed to reduce the example further! Will this be okay?

```
import equinox as eqx
import optimistix as...
```