Patrick Kidger
Hey there! Hmm, I don't think I understand this one. Doing just `python -c 'import jaxtyping'` shouldn't load pytest at all. In fact, with this, not even `jaxtyping._pytest_plugin` is loaded...
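One way to check this directly is to do the import in a fresh interpreter and then inspect `sys.modules`. Here is a minimal sketch. The helper name `import_pulls_in` is hypothetical, and the fresh subprocess keeps modules already loaded by the current process (e.g. a test runner) from skewing the answer:

```python
import subprocess
import sys


def import_pulls_in(importer: str, target: str) -> bool:
    """Return True if `import <importer>` leaves `target` in sys.modules.

    Runs in a fresh interpreter so that anything already imported in this
    process (e.g. by pytest itself) doesn't affect the result.
    """
    code = (
        "import importlib, sys; "
        f"importlib.import_module({importer!r}); "
        f"print({target!r} in sys.modules)"
    )
    result = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        check=True,  # raises if `importer` itself fails to import
    )
    return result.stdout.strip() == "True"
```

With jaxtyping installed, `import_pulls_in("jaxtyping", "pytest")` would be expected to return `False` if nothing else in the environment is pulling pytest in.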
Ah, hmm. Perhaps it's the way that `jaxtyping` checks whether JAX/Equinox are available. I'm speculating here, as `jaxtyping` itself is tiny. If you can, try benchmarking what happens when...
Ah, interesting! I've not seen `filter_spec` passed across a JIT boundary like this before. Usually this is something created *before* a JIT boundary, so that `params` can be split into...
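Here is a minimal plain-JAX sketch of that pattern. The original sentence is cut off, so this assumes the split is into a trained part and a frozen part. `partition`/`combine` are hypothetical helpers that mirror the idea behind Equinox's `eqx.partition`/`eqx.combine`. The point is that `filter_spec` is only used outside the compiled function:

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def partition(tree, spec):
    """Split `tree` into (selected, rest) using a same-shaped pytree of bools.

    Positions that don't belong to a given half are filled with None.
    """
    selected = jtu.tree_map(lambda x, keep: x if keep else None, tree, spec)
    rest = jtu.tree_map(lambda x, keep: None if keep else x, tree, spec)
    return selected, rest


def combine(a, b):
    """Inverse of `partition`: take each leaf from `a` unless it is None."""
    return jtu.tree_map(
        lambda x, y: y if x is None else x, a, b, is_leaf=lambda x: x is None
    )


# Hypothetical parameters; filter_spec marks which leaves are trained.
params = {"w": jnp.ones(3), "b": jnp.zeros(()), "scale": jnp.array(2.0)}
filter_spec = {"w": True, "b": True, "scale": False}

# The split happens *before* the JIT boundary...
trainable, frozen = partition(params, filter_spec)


@jax.jit
def loss(trainable, frozen, x):
    # ...so only arrays (plus None placeholders) cross it.
    p = combine(trainable, frozen)
    return p["scale"] * jnp.sum((p["w"] * x + p["b"]) ** 2)


# Differentiate with respect to the trainable half only.
grads = jax.grad(loss)(trainable, frozen, jnp.arange(3.0))
```

Keeping `filter_spec` outside matters under plain `jax.jit`, because Python bools passed in as arguments would be traced, so a check like `if keep` inside the compiled function would fail.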
Closing as resolved in #1038!
I have no further information beyond what @johannahaffner and @johnviljoen have already shared, hopefully they can help! :)
I'm afraid there is no special support for these. I don't really recommend using them -- PyTorch itself seems to have mostly given up on adding further support for them....
You should separate your numerical JAX bit and the general 'software' bit of plotting. Do all JAX operations inside a jit'd region, then pass their output to the rest of...
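Here is a minimal sketch of that split. `compute` is a hypothetical stand-in for the numerical work, and `summarize` is a hypothetical stand-in for the plotting code, used so the example doesn't depend on matplotlib:

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def compute(xs: jax.Array) -> jax.Array:
    # Hypothetical numerical work: everything JAX-related lives inside
    # this single JIT-compiled function.
    return jnp.sin(xs) * jnp.exp(-0.1 * xs)


def summarize(ys: np.ndarray) -> str:
    # Hypothetical stand-in for the plotting code: plain NumPy/Python
    # only, with no JAX operations here.
    return f"{ys.size} points, min={ys.min():.3f}, max={ys.max():.3f}"


xs = jnp.linspace(0.0, 10.0, 200)
# Convert once at the boundary; from here on it's ordinary NumPy.
ys = np.asarray(compute(xs))
print(summarize(ys))
```

In real code `summarize` would be replaced by the actual plotting calls on `ys`. Those calls then only ever see NumPy arrays.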
My guess is that this is a JAX bug; we don't do much at this level of the stack. That aside, we'd definitely need an MWE to diagnose this one.
Through a process of substituting definitions, then deleting as much as I can, and repeating: I've managed to reduce the MWE down to one that no longer requires Optimistix. ```python import...
Opened a JAX-only MWE in https://github.com/jax-ml/jax/issues/31448 :)