Owen L
My go-to approach (not sure if it is the simplest) would be to make a custom solver. If you define a custom step function that returns a y update...
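A minimal pure-JAX sketch of the idea of a user-supplied step function that returns a y update. This is not diffrax's actual solver interface (a real custom solver in diffrax would subclass `diffrax.AbstractSolver`); `integrate`, `euler_step`, and the step signature are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def euler_step(f, t, y, dt, args):
    # Hypothetical user-defined step: returns the *update* to y, not the new y.
    return dt * f(t, y, args)


def integrate(f, step_fn, y0, t0, t1, num_steps, args=None):
    # Hypothetical fixed-step driver that applies the user-supplied step function.
    dt = (t1 - t0) / num_steps
    ts = t0 + dt * jnp.arange(num_steps)  # start time of each step

    def body(y, t):
        # Add the update returned by the step function to the current state.
        return y + step_fn(f, t, y, dt, args), None

    y_final, _ = jax.lax.scan(body, y0, ts)
    return y_final


def decay(t, y, args):
    # dy/dt = -y, whose exact solution at t = 1 is exp(-1).
    return -y


y1 = integrate(decay, euler_step, jnp.array(1.0), 0.0, 1.0, 1000)
print(float(y1))
```

Because the step returns an update rather than the new state, swapping in a different scheme only means passing a different `step_fn`.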
I think the code can be simplified further:
```python
from typing import Type

import equinox as eqx
import jax
import jax.numpy as jnp
import jaxtyping as jt
# from...
```
Can confirm: if you make that change in the code provided, it works.
Given equinox's support for and integration with abstract classes, I think that's definitely the way to go. But for the sake of completeness, structural subtyping would also work here (Abstract/final? What...
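For comparison, a plain-Python sketch of nominal subtyping (an `abc.ABC` base class) versus structural subtyping (`typing.Protocol`). It does not use equinox, and `AbstractStep`, `SupportsStep`, `InheritedDecay`, and `DuckDecay` are hypothetical names.

```python
import abc
from typing import Protocol, runtime_checkable


class AbstractStep(abc.ABC):
    # Nominal subtyping: implementations must explicitly inherit from this base.
    @abc.abstractmethod
    def step(self, y: float, dt: float) -> float: ...


@runtime_checkable
class SupportsStep(Protocol):
    # Structural subtyping: any class with a `step` method satisfies this.
    def step(self, y: float, dt: float) -> float: ...


class InheritedDecay(AbstractStep):
    def step(self, y: float, dt: float) -> float:
        return y - dt * y


class DuckDecay:
    # Inherits from nothing, but still matches SupportsStep structurally.
    def step(self, y: float, dt: float) -> float:
        return y - dt * y


print(isinstance(DuckDecay(), SupportsStep))  # True
print(isinstance(DuckDecay(), AbstractStep))  # False
```

Note that `isinstance` against a `runtime_checkable` protocol only checks that `step` exists, not its signature; static type checkers such as mypy do check the signature.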
Wouldn't that also mean bumping the jax dependency (a fair amount, from jax>=0.4.13)?
Is there an MVC? This code doesn't run for me.
I like this idea as well. I've thought about suggesting speed regression tests to diffrax before (since speed has come up in my work and for others in the issues), but unifying...
Running this code in a fresh Colab environment (with some `block_until_ready`s added), I see that array is faster but custom is slower:
```
24.5 µs ± 9.12 µs per loop (mean ±...
```
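A minimal timing sketch showing why the `block_until_ready` calls matter: JAX dispatches work asynchronously, so timing without them can end up measuring dispatch rather than the computation. The jitted `f` is a placeholder, not the array/custom code being compared in this thread.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Placeholder workload; should return (approximately) all ones.
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


x = jnp.linspace(0.0, 1.0, 1000)
f(x).block_until_ready()  # warm-up call so compilation is not timed

# Block on each result so the timer covers the actual computation.
n = 100
seconds = timeit.timeit(lambda: f(x).block_until_ready(), number=n)
print(f"{seconds / n * 1e6:.1f} µs per call")
```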
I'm open to discussion on the best interface. The current caution around bias is OK, but it could be a footgun (although the original paper doing the reshaping is already kind of a...
> but as these are scalars that I never explicitly cast to jax arrays I thought I'd be fine.

`self.len = jnp.sqrt(self.a**2 + self.b**2)`: this casts it to a jax...
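A small check of the point above: even though `a` and `b` are plain Python floats, `jnp.sqrt` returns a JAX array, whereas `math.sqrt` keeps a Python float. The variable names are illustrative. The distinction plausibly matters here because `eqx.filter_jit` traces array leaves and treats non-array leaves such as Python floats as static.

```python
import math

import jax
import jax.numpy as jnp

a, b = 3.0, 4.0  # plain Python floats

# a**2 + b**2 is still a Python float, but jnp.sqrt returns a JAX array.
length_jnp = jnp.sqrt(a**2 + b**2)
# math.sqrt keeps the result as a Python float.
length_math = math.sqrt(a**2 + b**2)

print(isinstance(length_jnp, jax.Array))  # True
print(type(length_math))  # <class 'float'>
```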