Jake Vanderplas
Hi, thanks for the report! I think this performance discrepancy comes from the fact that the batching rule for `dynamic_slice` is implemented in terms of `gather`, and `gather` is much...
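To see this for yourself, here is a minimal sketch that batches a `dynamic_slice` with `vmap` and checks the traced program for a `gather`. The function name `take_window` and the toy shapes are made up for illustration, and the primitive that shows up in the jaxpr may differ between JAX versions, so treat the check as something to inspect rather than a guarantee.

```python
import jax
import jax.numpy as jnp
from jax import lax


def take_window(x, start):
    # Hypothetical helper: slice a length-2 window out of a 1D array
    # at a start index that is only known at run time.
    return lax.dynamic_slice(x, (start,), (2,))


x = jnp.arange(12.0).reshape(3, 4)
starts = jnp.array([0, 1, 2])

# Batch over both the rows of x and the per-row start indices.
batched = jax.vmap(take_window)(x, starts)

# Trace the batched function and look at which primitives it uses.
jaxpr_text = str(jax.make_jaxpr(jax.vmap(take_window))(x, starts))
uses_gather = "gather" in jaxpr_text
print(uses_gather)
```

If `uses_gather` is true on your install, the batched slice is being lowered through `gather`, which is the situation described above.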
You may be able to do what you want via the `donate_argnums`/`donate_argnames` parameter of `jax.jit`; see the [`jax.jit` documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) for a description. For example: ```python from functools import partial import...
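As a rough sketch of what buffer donation can look like (the function `update` and the arrays here are hypothetical, not taken from the original report): marking an argument as donated tells JAX it may reuse that input buffer for the output. Whether the buffer is actually reused depends on the backend and the JAX version.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical update step: argument 0 is donated, so JAX may reuse
# its buffer for the result instead of allocating a new one.
@partial(jax.jit, donate_argnums=0)
def update(state, delta):
    return state + delta


state = jnp.zeros(4)
delta = jnp.ones(4)

# After this call, treat the old `state` as invalid and use only
# the returned array.
new_state = update(state, delta)
```

The usual pattern is to rebind the name right away, e.g. `state = update(state, delta)`, so the donated array is never touched again.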
There is a new experimental `ArrayRef` object that may serve the purposes of the question here: https://docs.jax.dev/en/latest/array_refs.html This should be part of the v0.7.1 release.
Closing because I don't think there's any other action to take here. Thanks for raising the issue!
It's hard to say much definitive without a full reproduction, including what operations `state.apply_fun` are doing, but in general it's not surprising that taking derivatives with respect to different inputs...
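To illustrate the general point with a toy case (the function `loss` and the shapes below are stand-ins I made up, not the `state.apply_fun` from the report): differentiating the same function with respect to different arguments produces different backward computations and differently shaped results.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Hypothetical scalar loss over a small dense layer.
    return jnp.sum(jnp.tanh(x @ w))


w = jnp.ones((3, 2))
x = jnp.ones((5, 3))

# Gradient with respect to the weights (argument 0).
grad_w = jax.grad(loss, argnums=0)(w, x)

# Gradient with respect to the inputs (argument 1).
grad_x = jax.grad(loss, argnums=1)(w, x)

print(grad_w.shape, grad_x.shape)
```

Comparing `jax.make_jaxpr(jax.grad(loss, argnums=0))(w, x)` with the `argnums=1` version is one way to see how the two traced programs differ.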
I'm having trouble running your code: can you point to what the `modeling` package is? These details help, but my first response still holds: differentiating with respect to different arguments...
A comment here: the wording implies that `None` has a special meaning – but my understanding is that the spec makes no requirements of what objects libraries use to represent...