diffrax
Have SDEs use `RecursiveCheckpointAdjoint` rather than `DirectAdjoint`
This relies on JAX implementing a way to detect symbolic zero tangents to a `custom_vjp`.