Patrick Kidger

Results 1456 comments of Patrick Kidger

If the question is directed at me -- happy to help out with anything needed from the Equinox side / if you think it's best to upstream anything into Equinox...

Haha! In that case, SGTM!

So I'm guessing that you're running this on GPU? I recognise this `gather`-to-`dynamic_slice` rewrite as being an optimization that XLA:GPU (and to my knowledge, nothing else) performs. What surprises me...

> No, I'm on TPUs (v4-32 to be specific)

Got it. I assume XLA:TPU has the same pass then.

> I assume lax.gather/jnp.take is more informative to XLA that Embedding...

Hi @patrick-toulme ! Thanks for taking a look at this. At least at the jaxpr level, I think it is naive indexing that is promised to be in-bounds: `mode=GatherScatterMode.PROMISE_IN_BOUNDS` Whilst...
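
A minimal sketch of how the jaxpr-level claim above could be checked, assuming a small made-up table and index array (`lookup`, `table` and `idx` are hypothetical names, not taken from the thread). It traces NumPy-style integer-array indexing with `jax.make_jaxpr` and prints the `gather` equation, so that the `mode=` parameter can be read off directly. The exact printed form may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def lookup(table, idx):
    # NumPy-style indexing with an integer array of row indices.
    return table[idx]


# Hypothetical inputs, chosen only for illustration.
table = jnp.arange(12.0).reshape(4, 3)
idx = jnp.array([0, 2, 3])

# Trace the function to a jaxpr without running it on an accelerator.
jaxpr = jax.make_jaxpr(lookup)(table, idx)

# Print only the lines that mention gather; the `mode=` parameter on that
# equation shows how out-of-bounds indices are treated at the jaxpr level.
for line in str(jaxpr).splitlines():
    if "gather" in line:
        print(line.strip())
```

This only inspects what JAX hands to XLA. Whatever rewrites XLA applies afterwards are not visible at this level.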

Sorry, it's not totally clear to me what change you're suggesting. Can you expand?

Ah, I see what you're saying! So I think much like QKV fusion, this would unfortunately be a backward-incompatible change. For specifically the purposes of sharding, then I think whatever...

I don't think this code works at runtime:

```python
> import jax.numpy as jnp
> (jnp.arange(3),) + jnp.arange(3)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../site-packages/jax/_src/numpy/array_methods.py",...
```

Ah, right! :D So `jaxtyping.Array` is actually just a re-export of `jax.Array`. If there is an issue here then it'll either be in the typechecker or in JAX, I'm not...

I think I would need a MWE to see what's going on for you, I'm afraid! (If it means anything, I have used Optimistix as part of modelling chemical networks,...