Jake Vanderplas
I don't know of a good solution. XLA is not well set up for general unstructured sparse computation, unfortunately. Now, if you have block sparsity then you may be able...
See https://github.com/google/jax/pull/23674 for an example of an efficient block-sparse approach for one particular operation of interest.
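As a rough illustration of why block sparsity fits XLA better than unstructured sparsity, here is a minimal sketch (not the approach from the linked PR). It assumes a matrix stored as a stack of dense `block_size x block_size` blocks plus their block-row and block-column indices. All of the arithmetic is a dense batched matmul, followed by a scatter-add. The function name and storage layout are hypothetical and chosen for illustration. In JAX, `np.einsum` maps to `jnp.einsum`, and the in-place `np.add.at` would become the functional `y.at[block_rows].add(partial)`.

```python
import numpy as np

def block_sparse_matvec(blocks, block_rows, block_cols, x, n_block_rows, block_size):
    """Hypothetical helper: compute y = A @ x for a block-sparse A.

    blocks:     (n_stored_blocks, block_size, block_size) dense nonzero blocks
    block_rows: (n_stored_blocks,) block-row index of each stored block
    block_cols: (n_stored_blocks,) block-column index of each stored block
    """
    # Split x into block-sized chunks: shape (n_block_cols, block_size).
    x_blocks = x.reshape(-1, block_size)
    # Dense batched matmul over the stored blocks only.
    partial = np.einsum("bij,bj->bi", blocks, x_blocks[block_cols])
    # Scatter-add each block's contribution into its output row block.
    y = np.zeros((n_block_rows, block_size), dtype=partial.dtype)
    np.add.at(y, block_rows, partial)
    return y.reshape(-1)
```

Only the stored blocks are touched, yet each operation stays dense and regularly shaped. That regularity is what makes this pattern compile well, unlike per-element gather/scatter over arbitrary indices.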
Thanks for the report – this issue is tracked in #8755. To be quite honest, we're not putting much effort into the `jax.experimental.sparse` code these days, so this issue is...
Closing as a duplicate of #8755
I think the issue here is that your gradients are very close to zero, so very small absolute deviations turn into large relative deviations:
```python
grad_jit_result = jax.grad(jax.jit(recon_loss))(example_batch)
grad_result =...
```
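To make the absolute-versus-relative point concrete, here is a small sketch with made-up gradient values. These are hypothetical and not the numbers from this report. Near zero, a difference far below any reasonable absolute tolerance can still be a sizeable fraction of the value itself.

```python
import numpy as np

# Hypothetical gradient values, chosen only to illustrate the effect;
# they are not taken from the original report.
grad_a = np.array([1.0e-9, 2.0])
grad_b = np.array([1.1e-9, 2.0])

abs_diff = np.abs(grad_a - grad_b)
# Guard against division by zero for entries that are exactly 0.
rel_diff = abs_diff / np.maximum(np.abs(grad_b), np.finfo(grad_b.dtype).tiny)

print(abs_diff.max())               # on the order of 1e-10
print(rel_diff.max())               # roughly 0.09, i.e. a ~9% relative difference
print(np.allclose(grad_a, grad_b))  # True: the default atol=1e-8 absorbs it
```

Comparing with `np.allclose` (or `np.testing.assert_allclose`) using an `atol` chosen for the scale of the gradients is usually more meaningful than looking at relative differences for entries that are essentially zero.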
Hmm, that's strange indeed. In general we don't expect JIT-compiled versions of functions to have bitwise-identical outputs to the non-compiled versions: any time you rearrange or fuse floating point operations,...
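The reason reordering matters is that floating-point addition is not associative. A minimal example in plain Python (not JAX-specific) shows two mathematically equal sums rounding differently. A compiler that reorders or fuses operations can produce the same kind of last-bit difference:

```python
import math

# Same three numbers, summed in two different orders.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left, right)               # 0.6000000000000001 0.6
print(left == right)             # False: not bitwise identical
print(math.isclose(left, right)) # True: equal up to rounding
```

This is why comparisons between JIT-compiled and non-compiled outputs should use a tolerance rather than exact equality.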
Does something like this answer help you? https://stackoverflow.com/questions/60019006/can-we-plot-image-data-in-altair
(answered in #22050)
Does this mean that libraries which don't explicitly define `__iadd__` are out of compliance? JAX is an example: `__iadd__` is undefined, so it falls back to the `__add__` behavior, which...
So effectively, then, the array API specification is saying "compliant libraries must overload `__iadd__`", because the behavior you get when falling back to a compliant implementation of `__add__` is non-compliant.
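The fallback in question is standard Python semantics: if a type does not define `__iadd__`, then `x += y` is evaluated as `x = x + y`. This rebinds the name to a new object instead of mutating the existing one. The sketch below uses two hypothetical classes, `NoIAdd` and `WithIAdd`, to show the observable difference through an alias. It demonstrates the language-level behavior only and makes no claim about exactly what the array API specification requires.

```python
class NoIAdd:
    """Hypothetical array-like type defining __add__ but not __iadd__."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        # Returns a new object; self is left unchanged.
        return NoIAdd(self.value + other.value)


class WithIAdd:
    """Hypothetical type that also defines an in-place __iadd__."""
    def __init__(self, value):
        self.value = value

    def __add__(self, other):
        return WithIAdd(self.value + other.value)

    def __iadd__(self, other):
        # Mutates self and returns it, so aliases see the update.
        self.value += other.value
        return self


a = NoIAdd(1)
alias_a = a
a += NoIAdd(2)             # falls back to a = a + NoIAdd(2)
print(a.value, alias_a.value, a is alias_a)  # 3 1 False

m = WithIAdd(1)
alias_m = m
m += WithIAdd(2)           # calls m.__iadd__(WithIAdd(2))
print(m.value, alias_m.value, m is alias_m)  # 3 3 True
```

With only `__add__`, the alias still holds the old value. Whether that matters for compliance depends on what the specification says in-place operators must do.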