Jake Vanderplas

642 comments by Jake Vanderplas

Thanks for the report. I'm not sure what's going on, but it seems others are also noticing this: https://stackoverflow.com/questions/78486071/why-does-jax-compilation-time-grow-with-vmap-batch-size

I can repro on a Colab A100; I thought it might somehow have to do with constant folding, but even passing `fmat` as an argument and defining `one` and `ones` in...
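A minimal way to observe compile-time growth with batch size is to time JAX's ahead-of-time lowering and compilation for a `vmap`-ed function. This is a sketch only: `f` below is a hypothetical stand-in, not the `fmat` reproduction from the report.

```python
import time
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical per-example function; stand-in for the reported case.
    return jnp.sum(jnp.sin(x) * jnp.cos(x))

for batch in (10, 100, 1000):
    xs = jnp.ones((batch, 8))
    start = time.perf_counter()
    # AOT lower + compile: measures compilation only, not execution.
    jax.jit(jax.vmap(f)).lower(xs).compile()
    print(batch, time.perf_counter() - start)
```

Using `.lower(...).compile()` rather than calling the jitted function keeps execution and dispatch overhead out of the measurement.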

We only removed the `arr.device()` method in JAX v0.4.27 – to avoid confusion for users, I think we should wait for one more release (0.4.29) before we add the...
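For context, the replacement in current JAX is the `arr.devices()` method, which returns a set of devices and so also covers sharded arrays. A minimal sketch, assuming a single-device CPU array:

```python
import jax.numpy as jnp

x = jnp.arange(4)
# devices() returns a set (one entry for an unsharded array),
# which generalizes cleanly to multi-device sharded arrays.
print(x.devices())
```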

Thanks for the report! This looks like a float precision issue. Scipy uses 64-bit precision, while JAX uses 32-bit precision by default. If you enable 64-bit precision in JAX, you...
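Enabling 64-bit precision can be sketched as follows; the flag should be set before any arrays are created, since existing arrays keep their dtype:

```python
import jax
# Enable 64-bit types globally; do this at startup, before creating arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
x = jnp.array(1.0)
print(x.dtype)  # float64 once x64 is enabled (float32 is the default)
```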

Hi! Thanks for the request – we've discussed this previously in #14802, and I think the answer there still reflects the thinking of the core team.

Thanks for the report! This code is a few years old now and the author is no longer working on the JAX project. I took a look and I found...

For what it's worth, I think the current behavior is defensible: e.g. if you have 64 shards that all error, it's not terrible to only see one copy of the...

+1: we've run into this with `ml_dtypes.bfloat16`, which is why we wrote our own `finfo` wrapper: https://github.com/jax-ml/ml_dtypes/blob/main/ml_dtypes/_finfo.py This is an occasional gotcha for users. A public `finfo`/`iinfo` registration mechanism for...
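The wrapper linked above can be used directly; a minimal sketch, assuming `ml_dtypes` is installed:

```python
import ml_dtypes

# ml_dtypes ships its own finfo so that low-precision types like
# bfloat16 report correct parameters (bits, eps, max, ...).
info = ml_dtypes.finfo(ml_dtypes.bfloat16)
print(info.bits)        # 16
print(float(info.eps))  # 0.0078125, i.e. 2**-7 for bfloat16's 8-bit mantissa
```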

Thanks for the question! JAX's sparse-sparse matmul is pretty inefficient, but for good reason: for completely unstructured sparsity when the indices are not known statically, that's the worst-case non-aggregated nse...
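The worst case mentioned here can be illustrated with plain Python over COO-style coordinates: every nonzero `A[i, j]` pairs with every nonzero `B[j, l]`, so before duplicate `(i, l)` entries are aggregated, the output can have up to `nse(A) * nse(B)` entries. A sketch (hypothetical helper, not the JAX implementation):

```python
def matmul_nse(a_coords, b_coords):
    """Non-aggregated output entries of a sparse-sparse matmul.

    a_coords: iterable of (i, j) nonzero positions of A.
    b_coords: iterable of (j, l) nonzero positions of B.
    """
    pairs = [(i, l) for (i, j) in a_coords
                    for (jj, l) in b_coords if j == jj]
    # Duplicate (i, l) entries are not yet summed: worst-case bound.
    return len(pairs)

# Worst case: all of A's nonzeros in one column, all of B's in that row.
a = {(i, 0) for i in range(8)}
b = {(0, l) for l in range(8)}
print(matmul_nse(a, b))  # 64 == len(a) * len(b)
```

When indices are not known statically, a compiler must allocate for this bound, which is why unstructured sparse-sparse matmul is inherently expensive.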

Interesting! For what it's worth, this example illustrates the fundamental challenge with building a general sparse computing API. What is meant by "sparsity" is very specific to each particular context:...