Jake Vanderplas

646 comments by Jake Vanderplas

It looks like one possible workaround for now is to use an optimization barrier:
```python
from jax._src.ad_checkpoint import _optimization_barrier

def build_sparse_linear_operator():
  nonzeroes = sparse.eye(n).indices  # shape (n,2)
  def product(other):
    nz...
```

That said, it's probably worth filing an XLA bug for this. It should be something that the compiler handles automatically: https://github.com/openxla/xla

Sounds like this change I made a couple of years ago: https://github.com/jakevdp/PythonicPerambulations/commit/33a258d05ec1ad48445568094887887da1e413e8

Awesome, consider it assigned 😁

I'm not actively working on this, and I don't know of anybody who is.

I would suggest not doing all of this manually; instead, use
```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
The advice above was for working around an issue with this built-in setup. If...
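For notebooks that may also run outside a Colab TPU runtime, one option is to guard the built-in setup call. The sketch below assumes the runtime advertises its TPU through the `COLAB_TPU_ADDR` environment variable, as legacy Colab TPU runtimes did; the helper name `maybe_setup_colab_tpu` is hypothetical and not part of JAX.

```python
import os


def maybe_setup_colab_tpu(environ=None):
    """Call JAX's built-in Colab TPU setup only if a Colab TPU seems present.

    Assumption: the runtime exposes the TPU address via COLAB_TPU_ADDR.
    Returns True if setup was attempted, False otherwise.
    """
    environ = os.environ if environ is None else environ
    if "COLAB_TPU_ADDR" not in environ:
        return False
    # Imported lazily so the helper is harmless where no TPU is configured.
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    return True


if maybe_setup_colab_tpu():
    print("Colab TPU setup requested")
else:
    print("No Colab TPU detected; using the default backend")
```

Passing an explicit mapping instead of relying on `os.environ` makes the check easy to exercise without a TPU attached.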

Can you say more about where you're seeing this issue? I just ran the following on a fresh Colab TPU runtime and got the expected output:
```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
...
```

Thanks - that looks unrelated to this TPU issue. It would be helpful if you could open a [new discussion](https://github.com/google/jax/discussions/new?category=q-a), and if possible include a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) of the...