Patrick Kidger

1456 comments by Patrick Kidger

Ah, interesting! Yeah, I'm definitely willing to believe that the nested `jax.jit` is the root cause of some funny business here.

With a bit of trickery, this is possible!

```python
from typing import TypeVar, Union, TYPE_CHECKING
from jaxtyping import Array, Float32, Shaped

if TYPE_CHECKING:
    # needed for static type checking compatibility
    ...
```

Haha, thank you! I like this idea, and I think the approach of using `jax.debug.inspect_array_sharding` is probably the correct one. Regarding the syntax, I don't think we can use `/`,...
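For reference, a minimal sketch of what `jax.debug.inspect_array_sharding` does: it reports the sharding JAX chose for an intermediate value inside a jitted function by passing a `jax.sharding.Sharding` to a callback. The callback fires while the function is being compiled, not on every call. The function `f` and the list-collecting callback are illustrative choices.

```python
import jax
import jax.numpy as jnp

seen_shardings = []


@jax.jit
def f(x):
    y = jnp.sin(x) * 2.0
    # The callback receives a jax.sharding.Sharding describing how `y` is laid
    # out across devices. It runs during compilation, not on every call.
    jax.debug.inspect_array_sharding(y, callback=seen_shardings.append)
    return y


out = f(jnp.ones((8, 4)))
print(seen_shardings[0])
```

A runtime shape/sharding checker could, under this approach, compare the reported sharding against the one declared in an annotation.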

I like your syntax suggestions! My only suggestion would probably be to switch them around: `data_parallel!batch`. I know this isn't the convention in the literature, but so far the convention in...
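A small parser for the proposed syntax, purely for illustration. It assumes the token before `!` names the mesh axis and the token after names the array dimension; `parse_shape_string` is a hypothetical helper, not part of jaxtyping.

```python
from __future__ import annotations

from typing import Optional


def parse_shape_string(spec: str) -> tuple[list[str], list[Optional[str]]]:
    """Split a shape string into dimension names and mesh-axis names.

    Hypothetical syntax: a token written as "axis!dim" means array dimension
    `dim` is sharded over mesh axis `axis`; a bare token "dim" means that
    dimension is not sharded.
    """
    dim_names: list[str] = []
    mesh_axes: list[Optional[str]] = []
    for token in spec.split():
        parts = token.split("!")
        if len(parts) == 1:
            # Unsharded dimension.
            mesh_axes.append(None)
            dim_names.append(token)
        elif len(parts) == 2 and all(parts):
            # Sharded dimension: mesh axis first, then the dimension name.
            axis, dim = parts
            mesh_axes.append(axis)
            dim_names.append(dim)
        else:
            raise ValueError(f"malformed dimension token: {token!r}")
    return dim_names, mesh_axes


names, axes = parse_shape_string("data_parallel!batch channels")
# names == ["batch", "channels"], axes == ["data_parallel", None]
```

The `axes` list is already in the form `jax.sharding.PartitionSpec(*axes)` expects, one entry per array dimension.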

Thanks @yashk2810! Okay, so on balance, I think I'm inclined to go with the `Float[Array, "foo", PartitionSpec(...)]` syntax. The rationale is:

- The shape and the parallelism strategy are...
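To make the `PartitionSpec(...)` part concrete: in JAX, a `PartitionSpec` pairs each array dimension with a mesh axis name (or `None` for an unsharded dimension), and a `NamedSharding` binds that spec to a device mesh. A minimal sketch over whatever devices are available; the axis name `"batch"` and the array shape are arbitrary choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh over all available devices, with a single axis called "batch".
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Shard the leading dimension over the "batch" mesh axis; replicate the second.
spec = P("batch", None)
sharding = NamedSharding(mesh, spec)

# The leading dimension must be divisible by the number of devices.
x = jnp.zeros((len(devices) * 4, 3))
x_sharded = jax.device_put(x, sharding)

print(x_sharded.sharding.spec)
```

Under the chosen syntax, an annotation's `PartitionSpec` would presumably describe the same information as `spec` here, kept separate from the shape string.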

Hey there! I'm afraid this isn't on my roadmap at the moment. I'd love to have it in, but it's not something I have time for myself right now. I...

Hmm, I'm a little mystified by this, because I thought we had already added support for it (https://github.com/patrick-kidger/equinox/issues/259, https://github.com/patrick-kidger/equinox/commit/c5fc44f4acff02f1b2c24f5f39f009c1b5ff5967). Indeed, in the line just above your error, we have an...

Right! So I think what you're trying to do here is reasonable. I'd be happy to take a PR adjusting this. (Maybe we just consider all kinds of JAX and...

Make sure to wrap your programs in `jax.jit` or `equinox.filter_jit`. This is an important thing to do for all JAX programs if you want good performance :) (Once this has...
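A minimal illustration of this advice, using a hypothetical least-squares `loss`. `jax.jit` traces the function once per new combination of input shapes and dtypes and compiles it with XLA, so later calls run one compiled program instead of dispatching each operation from Python. `equinox.filter_jit` plays the same role for functions whose arguments mix arrays with non-array Python objects; it is left out here to keep the sketch dependent on JAX alone.

```python
import jax
import jax.numpy as jnp


def loss(params, x, y):
    # A small least-squares loss: several elementwise ops that XLA can fuse.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# Wrap once, then reuse: compilation happens on the first call for each
# new combination of input shapes and dtypes.
fast_loss = jax.jit(loss)
fast_grad = jax.jit(jax.grad(loss))

params = {"w": jnp.ones((3,)), "b": jnp.array(0.0)}
x = jnp.ones((5, 3))
y = jnp.zeros((5,))

print(fast_loss(params, x, y))
print(fast_grad(params, x, y)["w"])
```

The jitted and un-jitted versions compute the same values; only the execution strategy differs.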

Hmm, this looks like an upstream bug in JAX. I'd suggest opening an issue over there. Here's a minimal repro without using jaxtyping:

```python
import beartype
import jax

# Fails...
```