Jake Vanderplas

424 comments by Jake Vanderplas

It looks like one possible workaround for now is to use an optimization barrier:

```python
from jax._src.ad_checkpoint import _optimization_barrier

def build_sparse_linear_operator():
    nonzeroes = sparse.eye(n).indices  # shape (n,2)
    def product(other):
        nz...
```
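The snippet in that comment is cut off, so here is a minimal self-contained sketch of the same pattern. It is not the original code. It assumes a JAX release that exposes the public `jax.lax.optimization_barrier` (falling back to the private `_optimization_barrier` from the comment on older versions), takes `n` as an explicit parameter, and fills in a hypothetical `product` body that applies the sparse operator to a 1-D vector by gathering and scatter-adding along the stored indices. The barrier sits between the closed-over indices/values and their use, so XLA treats them as an opaque value at that point instead of optimizing across it.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

try:
    # Public API in recent JAX releases.
    from jax.lax import optimization_barrier
except ImportError:
    # Older releases: the private function used in the comment above.
    from jax._src.ad_checkpoint import _optimization_barrier as optimization_barrier


def build_sparse_linear_operator(n):
    # BCOO identity matrix; indices have shape (n, 2), data has shape (n,).
    mat = sparse.eye(n)
    nonzeroes, values = mat.indices, mat.data

    @jax.jit
    def product(other):
        # Hypothetical body (the original is truncated). The barrier keeps XLA
        # from moving or folding computation across this point, so the
        # closed-over indices and values are consumed as opaque inputs.
        nz, vals = optimization_barrier((nonzeroes, values))
        rows, cols = nz[:, 0], nz[:, 1]
        # Sparse matrix-vector product: out[row] += value * other[col].
        return jnp.zeros_like(other).at[rows].add(vals * other[cols])

    return product
```

Since the example operator is the identity, `build_sparse_linear_operator(5)(jnp.arange(5.0))` should return its input unchanged; swapping `sparse.eye(n)` for another 2-D BCOO matrix exercises the same code path.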

That said, it's probably worth filing an XLA bug for this. It should be something that the compiler handles automatically: https://github.com/openxla/xla

Sounds like this change I made a couple of years ago: https://github.com/jakevdp/PythonicPerambulations/commit/33a258d05ec1ad48445568094887887da1e413e8

Awesome, consider it assigned 😁