Owen L

Results 338 comments of Owen L

> I can't comment on the others, but this one is expected and (still) a general limitation in JAX. Yea, I get that that error is a valid error, maybe...

> And sadly, XLA has a longstanding bug in which grad-of-loop-of-inplace will make copies of that buffer during the backward pass! Smh, why does Google not use a tiny bit...
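
For reference, the "grad-of-loop-of-inplace" pattern looks roughly like the sketch below: a `lax.fori_loop` that updates a buffer with `.at[i].set(...)`, differentiated with `jax.grad`. The function name `fill_buffer` and the sizes are made up for illustration. The sketch only reproduces the shape of the computation. It does not measure whether XLA actually copies the buffer in the backward pass, which depends on the JAX/XLA version.

```python
import jax
import jax.numpy as jnp


def fill_buffer(x, n=5):
    """Hypothetical loop that writes into a preallocated buffer, one slot per iteration."""
    buf = jnp.zeros(n)

    def body(i, buf):
        # Functional "in-place" update; under jit XLA can often do this without a copy.
        return buf.at[i].set(x * i)

    buf = jax.lax.fori_loop(0, n, body, buf)  # static trip count, so reverse-mode works
    return buf.sum()


# Reverse-mode through the loop: this is the grad-of-loop-of-inplace pattern.
grad_fill = jax.grad(fill_buffer)
print(grad_fill(2.0))  # d/dx of sum(x * i for i in range(5)) = 0 + 1 + 2 + 3 + 4
```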

> Happy to take a PR on this one! ;) As a strict adherent to software best practices I will ~~make a new PR for this~~ roll this into my...

I get them to be a lot closer by using `UnsafeBrownianPath`, which has less overhead than `VirtualBrownianTree` (VBT). Diffrax is still a bit slower with this change on my machine, but...

I think a lot of people get turned off by the `Unsafe` in the name; it may be worth adding a sentence like this to the docs ("In the meantime I recommend...

Diffrax has a lot more checking/shaping/logging than the default implementation. You can see it reflected in the jaxprs:

diffrax:
```
let _where = { lambda ; a:bool[] b:i32[] c:i32[]. let...
```
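
One quick way to see this kind of overhead is to compare the size of the traced programs. The sketch below is not diffrax. It compares a bare Euler loop with a version that carries a made-up finiteness check and status flag. The line count of the `jax.make_jaxpr` output serves as a crude size metric. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def euler_plain(y0, dt, n):
    # Bare fixed-step Euler loop for dy/dt = -y.
    def body(i, y):
        return y + dt * (-y)

    return jax.lax.fori_loop(0, n, body, y0)


def euler_checked(y0, dt, n):
    # Same loop plus hypothetical bookkeeping: a finiteness check and a status flag.
    def body(i, carry):
        y, ok = carry
        y_new = y + dt * (-y)
        finite = jnp.all(jnp.isfinite(y_new))
        y_new = jnp.where(finite, y_new, y)  # keep the old value if the step blew up
        return y_new, ok & finite

    return jax.lax.fori_loop(0, n, body, (y0, jnp.array(True)))


def jaxpr_lines(f, *args):
    # Crude size metric: lines in the pretty-printed jaxpr (nested jaxprs included).
    jaxpr = jax.make_jaxpr(f, static_argnums=(2,))(*args)
    return len(str(jaxpr).splitlines())


print(jaxpr_lines(euler_plain, 1.0, 0.1, 10))
print(jaxpr_lines(euler_checked, 1.0, 0.1, 10))
```

Printing the two jaxprs with `print(jax.make_jaxpr(...)(...))` and reading them side by side shows where the extra equations come from.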

With `throw=False`, `EQX_ON_ERROR=nan`, and `StepTo`, this is what I see:

```python
import os

os.environ["EQX_ON_ERROR"] = "nan"  # must be set before equinox/diffrax are imported
import diffrax as dx
import jax
import jax.numpy as jnp
from matplotlib...
```

The default actually errors with UBP, which is why I changed to `DirectAdjoint`:

```
ValueError: `adjoint=RecursiveCheckpointAdjoint()` does not support `UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` instead.
```

`DirectAdjoint` does slow things down, but it doesn't account for the whole gap. If I switch to a branch that allows UBP + `RecursiveCheckpointAdjoint`, it's faster, but there's still a ~4x gap...
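
A ratio like the ~4x above is only meaningful if both sides are timed at steady state. The generic sketch below shows one way to time such a comparison: compile once outside the timed region, block on JAX's asynchronous result, and report a median. The helper name `benchmark` and the toy function `f` are hypothetical. They are not the script used for the numbers in this thread.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=20):
    """Median wall-clock time (seconds) of a jitted call, excluding compilation."""
    jitted = jax.jit(fn)
    jax.block_until_ready(jitted(*args))  # warm-up call: triggers compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(jitted(*args))  # dispatch is async; wait for the result
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def f(x):
    # Toy workload standing in for a solver call.
    return jnp.sin(x) @ jnp.cos(x).T


x = jnp.ones((256, 256))
print(f"{benchmark(f, x) * 1e3:.3f} ms")
```

To compare two implementations, call `benchmark` on each with identical inputs and divide the medians.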

1. That is something I want to investigate as well (and also to organize more of it and push it to a fork for others to check), though admittedly it will take a little bit...