Patrick Kidger

267 comments by Patrick Kidger

Yep, this is expected. In JAX, all arrays must have a shape known at compile (JIT) time. But the number of steps may be dynamic (due to adaptive step size...
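To illustrate the constraint, here is a minimal plain-JAX sketch. It is not Diffrax's actual implementation: it is a toy integrator whose number of steps depends on the data but which still returns a fixed-shape output. It does this by writing into a preallocated buffer of length `MAX_STEPS` (a hypothetical static upper bound) and leaving unused slots as `inf`. The step-size rule is a made-up state-dependent heuristic, not real error control.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 64  # hypothetical static upper bound on the number of steps


@jax.jit
def integrate(y0, t1):
    """Toy integrator for dy/dt = -y on [0, t1] with a data-dependent step count."""

    def cond(state):
        t, _, i, _ = state
        # Stop when we reach t1 or run out of buffer space.
        return (t < t1) & (i < MAX_STEPS)

    def body(state):
        t, y, i, ys = state
        # Made-up state-dependent step size, clipped so we never overshoot t1.
        dt = jnp.minimum(0.1 / (1.0 + jnp.abs(y)), t1 - t)
        y = y + dt * (-y)  # explicit Euler step
        ys = ys.at[i].set(y)
        return t + dt, y, i + 1, ys

    # The buffer's shape is fixed at trace time; unused slots stay inf.
    ys = jnp.full((MAX_STEPS,), jnp.inf, dtype=jnp.float32)
    init = (jnp.float32(0.0), jnp.asarray(y0, jnp.float32), jnp.int32(0), ys)
    _, _, num_steps, ys = jax.lax.while_loop(cond, body, init)
    return ys, num_steps


ys, n = integrate(1.0, 1.0)
n = int(n)
ys2, n2 = integrate(5.0, 1.0)  # more steps, same output shape
```

Changing `y0` changes how many steps are taken but not the shape of `ys`, so the second call reuses the same compiled function. The price is that the caller has to use `num_steps` (or the `inf` padding) to know which entries are meaningful.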

So don't forget that the stepping happens under JIT, which does limit the sorts of output we can produce this way. What you could already do is use `jax.experimental.host_callback.{id_print,id_tap,call}`...
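`jax.experimental.host_callback` has since been deprecated, and `jax.debug.print` and `jax.debug.callback` are the recommended replacements. Below is a minimal sketch, assuming a recent JAX version, of getting values out of a jitted loop with them. The loop body and the `record` function are hypothetical stand-ins for a solver's stepping and for whatever host-side logging you actually want.

```python
import jax
import jax.numpy as jnp

collected = []  # host-side storage, filled in by the callback


def record(step, y):
    # Runs on the host with concrete values, not tracers.
    collected.append((int(step), float(y)))


@jax.jit
def run(y0):
    def body(i, y):
        y = 0.5 * y  # stand-in for one solver step
        jax.debug.print("step {i}: y = {y}", i=i, y=y)  # analogous to id_print
        jax.debug.callback(record, i, y, ordered=True)  # analogous to id_tap
        return y

    return jax.lax.fori_loop(0, 5, body, y0)


out = run(jnp.float32(8.0))
jax.effects_barrier()  # wait for any pending callbacks to finish
```

`ordered=True` makes the callbacks run in program order, so `collected` lists the steps in sequence.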

I think all of the above sound like reasonable benchmarks. I would suggest focusing primarily on the diffeq part of any problem though, rather than the entire ML system. It's...
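In that spirit, here is a minimal sketch of timing just the solve in isolation. The `solve` function is a hypothetical stand-in: a fixed-step Euler loop in plain JAX rather than a Diffrax call. The harness is meant to show two things. First, it times compilation separately from the steady-state cost. Second, it calls `block_until_ready` so that JAX's asynchronous dispatch doesn't make calls look artificially fast.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def solve(y0):
    """Explicit Euler for dy/dt = -y on [0, 1] with 1000 fixed steps."""
    dt = 1e-3

    def step(y, _):
        return y + dt * (-y), None

    y1, _ = jax.lax.scan(step, y0, None, length=1000)
    return y1


def benchmark(fn, *args, repeats=20):
    """Time the first call (trace + compile + run) separately from steady-state calls."""
    start = time.perf_counter()
    fn(*args).block_until_ready()
    first_call = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(repeats):
        # Without block_until_ready we would only time the dispatch.
        fn(*args).block_until_ready()
    per_call = (time.perf_counter() - start) / repeats
    return first_call, per_call


first_call, per_call = benchmark(solve, jnp.ones(1000))
print(f"first call: {first_call:.4f}s, steady state: {per_call:.6f}s")
```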

Right, so "controlled differential equations" are a specific notion, and are written like `dy(t) = f(y(t)) dx(t)`; in some sense they are indeed forced by the derivative of `x`. This...
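Concretely, an Euler discretisation of `dy(t) = f(y(t)) dx(t)` only ever uses increments of the control: `y_{k+1} = y_k + f(y_k) (x_{k+1} - x_k)`. When `x` is differentiable, this matches the ODE `dy/dt = f(y) dx/dt`. A minimal plain-JAX sketch follows, with a hypothetical scalar vector field `f(y) = y` and control `x(t) = sin(t)`. These are chosen because the exact solution `y(t) = y(0) exp(x(t) - x(0))` is known. This is an illustration only, not how Diffrax implements its CDE solvers.

```python
import jax
import jax.numpy as jnp


def f(y):
    return y  # hypothetical vector field: dy = y dx


def solve_cde_euler(y0, xs):
    # xs holds samples x(t_0), ..., x(t_n) of the control path.
    dxs = jnp.diff(xs)  # increments x(t_{k+1}) - x(t_k)

    def step(y, dx):
        y = y + f(y) * dx  # Euler step driven by the control increment
        return y, y

    return jax.lax.scan(step, y0, dxs)


ts = jnp.linspace(0.0, 3.0, 10_001)
xs = jnp.sin(ts)
y1, ys = solve_cde_euler(jnp.float32(1.0), xs)
exact = jnp.exp(xs[-1] - xs[0])
```

Only the increments of `x` appear. So the same code runs unchanged on a control sampled from data, where `dx/dt` may not be available at all.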

(If you're curious to know more about CDEs, then I'd recommend Chapter 3 of the recently released [On Neural Differential Equations](https://arxiv.org/abs/2202.02435) for an introduction. There are other references available too; I just...

> The most straightforward way would be to include the forcing term in `vector_field`, if one had a functional representation of it. Let's say that the forcing signal was...

Yep, that's exactly right.
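For instance, if the forcing signal is only available as samples, one way to get a functional representation is to interpolate it inside the vector field. Here is a minimal sketch. It assumes a hypothetical forced ODE `dy/dt = -y + u(t)`, with `u(t) = cos(t)` sampled on a grid and linearly interpolated with `jnp.interp`. The `(t, y, args)` signature mirrors Diffrax's vector fields, but the fixed-step RK4 loop is a plain-JAX stand-in for `diffeqsolve`.

```python
import math

import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    # args carries the sampled forcing signal (a hypothetical layout).
    ts_u, us = args
    u = jnp.interp(t, ts_u, us)  # piecewise-linear functional representation
    return -y + u


def rk4_solve(vf, y0, t0, t1, num_steps, args):
    """Fixed-step classical RK4; a stand-in for a proper adaptive solver."""
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        t, y = carry
        k1 = vf(t, y, args)
        k2 = vf(t + dt / 2, y + dt / 2 * k1, args)
        k3 = vf(t + dt / 2, y + dt / 2 * k2, args)
        k4 = vf(t + dt, y + dt * k3, args)
        return (t + dt, y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)), None

    (_, y1), _ = jax.lax.scan(step, (jnp.float32(t0), y0), None, length=num_steps)
    return y1


ts_u = jnp.linspace(0.0, 2.0, 201)  # sample times of the forcing signal
us = jnp.cos(ts_u)                  # sampled forcing values
y1 = rk4_solve(vector_field, jnp.float32(0.0), 0.0, 2.0, 200, (ts_u, us))

# Exact solution of dy/dt = -y + cos(t), y(0) = 0, for comparison.
exact = (math.cos(2.0) + math.sin(2.0) - math.exp(-2.0)) / 2
```

Linear interpolation is the simplest choice. A smoother interpolant may suit higher-order solvers better, since kinks in `u` limit the accuracy an adaptive solver can achieve between samples.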

Hmm. This looks like a JAX problem rather than a Diffrax one: specifically, host callbacks and TPUs aren't interacting properly for some users. You can probably work around this...

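I recently hit a similar issue. Just correcting my previous comment to give a better fix:

```python
import sys

# Look through every loaded Diffrax module for a `branched_error_if` attribute.
# Iterating over a copy avoids "dictionary changed size during iteration" if a
# module happens to be imported while the loop runs.
for module_name, module in list(sys.modules.items()):
    if module_name.startswith("diffrax"):
        if hasattr(module, "branched_error_if"):
            module.branched_error_if...
```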
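Presumably the loop over `sys.modules` is needed because `from package import function` binds a separate reference in each importing module's namespace. Rebinding the name in one module therefore leaves the others pointing at the original. Here is a self-contained sketch of the same pattern. It uses made-up modules `mypkg` and `mypkg.sub` and a hypothetical `check` function in place of the Diffrax internals.

```python
import sys
import types


def original_check(x):
    raise RuntimeError("check failed")


def patched_check(x):
    return x  # hypothetical no-op replacement


# Simulate two submodules that each did `from ... import original_check as check`,
# so each holds its own reference under the name `check`.
for name in ("mypkg", "mypkg.sub"):
    module = types.ModuleType(name)
    module.check = original_check
    sys.modules[name] = module

# Rebinding `check` in only one module would leave the other untouched,
# so patch every loaded module belonging to the package.
for module_name, module in list(sys.modules.items()):
    in_package = module_name == "mypkg" or module_name.startswith("mypkg.")
    if in_package and hasattr(module, "check"):
        module.check = patched_check
```

The `in_package` test is slightly stricter than a bare `startswith("mypkg")`, which would also match an unrelated package such as `mypkgextras`.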

Hey there! Hmm, this kind of scenario is rather annoying, I agree. I think this is probably a JAX issue, in that it is failing to properly clear its compilation...