Jake Vanderplas
It looks like one possible workaround for now is to use an optimization barrier:

```python
from jax._src.ad_checkpoint import _optimization_barrier

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n, 2)

    def product(other):
        nz...
```
That said, it's probably worth filing an XLA bug for this. It should be something that the compiler handles automatically: https://github.com/openxla/xla
This sounds like a change I made a couple of years ago: https://github.com/jakevdp/PythonicPerambulations/commit/33a258d05ec1ad48445568094887887da1e413e8
Awesome, consider it assigned 😁
I'm not actively working on this, and I don't know of anybody who is.
It looks like this has been fixed.
cc/ @tlu7
I would suggest not doing all of this manually; instead use

```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```

The advice above was for working around an issue with this built-in setup. If...
Can you say more about where you're seeing this issue? I just ran the following on a fresh Colab TPU runtime and got the expected output:

```python
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()...
```
Thanks - that looks unrelated to this TPU issue. It would be helpful if you could open a [new discussion](https://github.com/google/jax/discussions/new?category=q-a), and if possible include a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example) of the...