Patrick Kidger
This is quite nice. A few high-level comments to begin with. Can you:
- tweak the introduction to follow the same style as the current examples? In particular a bit...
Yep, checking if there was interest before polishing this definitely makes a lot of sense. Regarding the pre-commit hooks: see [CONTRIBUTING.md](https://github.com/patrick-kidger/diffrax/blob/main/CONTRIBUTING.md). These are a set of scripts that run when...
For `pytest` -- I think this is due to an upstream change in JAX. If you rebase/merge the current latest version of Diffrax, do you still get an error? (And...
And merged! Thank you for contributing.
This was a deliberate choice, as it felt more natural (at least to me) to have `evaluate` and `derivative` methods, rather than `__call__` and `derivative` methods. In reality the `derivative`...
I don't know if it's been considered, but one other option is run-time type-checkers. Personally I never use static type-checkers as I find jumping through their hoops more pain than...
FWIW, once the [jaxtyping rewrite](https://github.com/google/jaxtyping/tree/rewrite) goes in ¹ then jaxtyping will actually be PEP-compliant. It shouldn't actually need any special support from either runtime type checkers or static type checkers....
So `filter_pmap` definitely has a small amount of extra overhead compared with manually combining `jax.pmap` and `eqx.{partition,combine}`. This is just because `filter_pmap` is more general, so it needs...
Interesting! Yep, that's definitely not desired -- it sounds like your hardware might just be beefy enough that you're able to do the whole forward evaluation in a time that's...
Oh also - trying a profiler like [py-spy](https://github.com/benfred/py-spy) may be interesting as well. (Rather than just me speculating about which bit is slow.)