Owen L
> I can't comment on the others, but this one is expected and (still) a general limitation in JAX. Yea, I get that that error is a valid error, maybe...
> And sadly, XLA has a longstanding bug in which grad-of-loop-of-inplace will make copies of that buffer during the backward pass! Smh why does google not use a tiny bit...
> Happy to take a PR on this one! ;) As a strict adherent to software best practices I will ~~make a new PR for this~~ roll this into my...
I get them to be a lot closer by using `UnsafeBrownianPath`, which has less overhead than VBT. Diffrax is still a bit slower with this change on my machine, but...
I think a lot of people get turned off by the `Unsafe` in the name, so it may be worth adding a sentence like this to the docs ("In the meantime I recommend...
Diffrax has a lot more checking/shaping/logging than the default implementation. You can see it reflected in the jaxprs: diffrax ``` let _where = { lambda ; a:bool[] b:i32[] c:i32[]. let...
With `throw=False`, `EQX_ON_ERROR=nan` and step to, this is what I see. Code: ```python import os os.environ["EQX_ON_ERROR"] = "nan" import diffrax as dx import jax import jax.numpy as jnp from matplotlib...
The default actually errors with UBP, which is why I changed to the direct adjoint:
```
ValueError: `adjoint=RecursiveCheckpointAdjoint()` does not support `UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` instead.
```
`DirectAdjoint` does slow things down, but not all the way. If I switch to a branch that allows for UBP + recursive adjoint, it's faster, but there is still a ~4x gap....
1. That is something I want to investigate as well (and also organize more of it and push it to a fork for others to check), though admittedly it will take a little bit...