Patrick Kidger
This one should be safe from JAX initialisation, I believe. :)
This sounds pretty weird :) But I'm glad you got to the bottom of it!
I think you want to wrap your `options` in a `lax.stop_gradient`. As Johanna notes, `options` could in principle be anything, so we don't do this automatically -- we don't...
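(A minimal sketch of the suggestion above. The `options` dict and `"precond"` key here are hypothetical stand-ins, not the actual Lineax structures; the point is just that `lax.stop_gradient` maps over a whole pytree, detaching it from autodiff.)

```python
import jax
import jax.numpy as jnp

def loss(x, options):
    # Detach everything in `options` from the autodiff graph, so the
    # preconditioner (or any other auxiliary arrays) carries no gradient.
    options = jax.lax.stop_gradient(options)
    return jnp.sum(options["precond"] * x)

x = jnp.array([1.0, 2.0])
options = {"precond": jnp.array([3.0, 4.0])}
# Gradient w.r.t. x still flows; `options` contributes none of its own.
dx = jax.grad(loss)(x, options)
```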
Hmm, probably `stop_gradient` should just be applied to the arrays then, via `equinox.partition` and `equinox.combine`. Maybe we should just always apply such a stop-gradient to preconditioners? Mathematically the output should...
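(A sketch of the arrays-only variant, written in plain JAX so it is self-contained. The `tree_map` over `jax.Array` leaves mimics what `eqx.partition(..., eqx.is_array)` followed by `lax.stop_gradient` and `eqx.combine` would achieve: non-array leaves such as bools or callables pass through untouched.)

```python
import jax
import jax.numpy as jnp

def stop_gradient_arrays(pytree):
    # Detach only the array leaves of the pytree; leave every other
    # leaf (callables, bools, ...) exactly as it was.
    return jax.tree_util.tree_map(
        lambda leaf: jax.lax.stop_gradient(leaf)
        if isinstance(leaf, jax.Array)
        else leaf,
        pytree,
    )

# Example: a mixed pytree with an array and a non-array leaf.
options = {"precond": jnp.array([3.0, 4.0]), "flag": True}
detached = stop_gradient_arrays(options)
```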
You can do this by calling (and differentiating through) `solver.compute` directly. It's true that we don't offer this by default in Lineax. I think this is probably fairly niche, though?
Awesome! I'm glad that it works.

> However, is there any specific reason why `lax.while` was used to make CG instead of `eqxi.while_loop`? The latter is more versatile, right?

...
For this one I think let's aim to get this fixed upstream in JAX itself. I think it'd be really hard for us to work around this issue here in...
Sounds good! This has been on my wishlist a long time, but I've never sat down to sort it out. FWIW I believe XLA already does some level of algebraic...
Oh, that is a pretty awesome speedup. Sounds like we should find a way to get this in :) As for why the `einsum`, just that I find it more...
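(To illustrate the `einsum` readability point: the two forms below compute the same attention-style score matrix. The shapes are made up for the example; the `einsum` spelling names each contracted axis explicitly, while the matmul relies on remembering which axis gets transposed.)

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes: (heads, sequence, head_dim).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(k1, (2, 5, 8))
k = jax.random.normal(k2, (2, 5, 8))

# The einsum makes the contraction over the head dimension `d` explicit.
scores_einsum = jnp.einsum("hqd,hkd->hqk", q, k)
# The equivalent batched matmul against the transposed keys.
scores_matmul = q @ jnp.swapaxes(k, -1, -2)
```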
Oh nice! I think your benchmark script looks correct to me. Indeed I suspect XLA is pattern-matching on attention, then.