Patrick Kidger

Results: 1451 comments by Patrick Kidger

The overall computational cost will be `batch size × greatest number of steps for any batch element`. For example, if there is a batch of 2 elements, the first batch...
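As a rough illustration of the cost formula in the comment above, here is a minimal sketch in plain Python. The step counts are hypothetical and do not come from the original comment, and the helper names `batched_cost` and `unbatched_cost` are made up for this example. It only counts step evaluations and ignores any other overhead.

```python
def batched_cost(steps_per_element):
    """Step evaluations when every batch element keeps stepping
    until the slowest element has finished."""
    batch_size = len(steps_per_element)
    return batch_size * max(steps_per_element)


def unbatched_cost(steps_per_element):
    """Step evaluations if each element were solved on its own."""
    return sum(steps_per_element)


# Hypothetical step counts for a batch of three elements.
steps = [10, 100, 25]

print(batched_cost(steps))    # 3 * 100 = 300
print(unbatched_cost(steps))  # 10 + 100 + 25 = 135
```

The gap between the two numbers grows with the spread in step counts, so a single slow batch element dominates the batched cost.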

So `filter_shard` is just a thin wrapper around wsc (`lax.with_sharding_constraint`): https://github.com/patrick-kidger/equinox/blob/8191b113df5d985720e86c0d6292bceb711cbe94/equinox/_sharding.py#L40. It's designed to (a) make it easier to handle static arguments and (b) unify device/sharding information. So indeed it's...

> Any idea why XLA complains when using filter_shard or lax.with_sharding_constraint inside a JIT'd function with PRNG keys? Maybe it's more of a question/issue for JAX or XLA folks... That...

Actually, I think JAX is exactly that clever :) Optimizing `x+0` to just `x` is a simple optimization that XLA should perform for us. That said I'd be happy to...

Hmm, that's unfortunate if so. Quax is still a fairly experimental library, so I'd be happy to take suggestions on how we might adjust the internals to work around this....

I don't believe this is possible, unfortunately. For this I would recommend using JAX, and in particular the interpolation routines in [Diffrax](https://github.com/patrick-kidger/diffrax), as a better, more featureful option.

I can't replicate your issue I'm afraid. Running:
```python
import torch
from jaxtyping import Float

def simple_test_a(x: Float[torch.Tensor, "dim1"]) -> torch.Tensor:
    reveal_type(x)
    return x

def simple_test_b(x: Float[torch.Tensor, "dim1"]) -> float:...
```

This sounds reasonable to me. IIUC codspeed is a service for recording the values of benchmarks, and is otherwise exactly the same as `pytest-benchmark`? If so, this all sounds reasonable...

You need to have JAX installed as well. jaxtyping only has JAX as an optional dependency, to support also being used with PyTorch etc.

Sorry, missed this question. Yes, jaxtyping no longer depends on JAX. The name is now for historical reasons only! The syntax Miles is using is correct.