Jake Vanderplas
Thanks for the report - we definitely should have a docstring for that function. Fortunately, it's pretty straightforward: it basically calls `functools.reduce()` over `tree_leaves(tree)`: https://github.com/google/jax/blob/94aade035a5fdeb2d3ed6f1744fcf1fa16240b8c/jax/_src/tree_util.py#L248-L254 Would you be interested in...
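For example, on a small made-up pytree, `jax.tree_util.tree_reduce` gives the same result as calling `functools.reduce` on the flattened leaves. An optional initializer is passed through as the reduction's start value:

```python
import functools
import operator

import jax
import jax.numpy as jnp

# An illustrative pytree: a dict containing a nested list of scalars.
tree = {"a": jnp.array(1.0), "b": [jnp.array(2.0), jnp.array(3.0)]}

# tree_reduce applies a binary function across the flattened leaves.
total = jax.tree_util.tree_reduce(operator.add, tree)

# The same thing by hand: reduce over the leaves in flattening order.
manual = functools.reduce(operator.add, jax.tree_util.tree_leaves(tree))

# The third argument is an initializer, used as the start value.
total_plus_ten = jax.tree_util.tree_reduce(operator.add, tree, 10.0)
```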
We haven't implemented this because the semantics of `np.put_along_axis` are to modify the array in-place, and this is not possible in JAX because JAX arrays are immutable. I suspect you...
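Here is a quick illustration of the immutability point. Item assignment on a JAX array raises a `TypeError`, and the functional `.at[]` update returns a new array while leaving the original untouched:

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# In-place item assignment is not supported on JAX arrays.
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# The functional alternative returns a new array; x itself is unchanged.
y = x.at[0].set(1.0)
```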
One idea may be to define `jax.numpy.put_along_axis`, but add an extra `inplace` keyword that defaults to `True`, such that the function errors with the default value. Users could set `inplace=False`...
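A rough sketch of that proposed design. The function below is hypothetical and is not an existing JAX API. It raises under the default `inplace=True`. With `inplace=False` it builds the same kind of broadcast fancy index NumPy uses for `put_along_axis` and applies it with `.at[].set()`:

```python
import jax.numpy as jnp


def put_along_axis(arr, indices, values, axis, inplace=True):
    """Hypothetical JAX put_along_axis following the proposed `inplace` keyword."""
    if inplace:
        # JAX arrays are immutable, so the NumPy-compatible default cannot work.
        raise ValueError(
            "JAX arrays are immutable; pass inplace=False to get a new array."
        )
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)
    axis = axis % arr.ndim
    # Use `indices` along `axis` and a broadcastable arange along every other
    # dimension, so each entry of `indices` is paired with its position.
    index = []
    for dim, size in enumerate(arr.shape):
        if dim == axis:
            index.append(indices)
        else:
            shape = [1] * arr.ndim
            shape[dim] = size
            index.append(jnp.arange(size).reshape(shape))
    # Functional update: returns a new array rather than modifying `arr`.
    return arr.at[tuple(index)].set(values)
```

Users who pass `inplace=False` get a new array back. Users who rely on the NumPy default get an explicit error instead of silently unmodified data.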
No, but note that you should be able to use [`jnp.ndarray.at[]`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) to do anything that you might do with `put_along_axis`.
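For instance, a NumPy `put_along_axis` call along `axis=1` of a 2D array can be written with `.at[]` by pairing each row index with its column index. The array and index values here are made up for illustration:

```python
import numpy as np
import jax.numpy as jnp

a_np = np.arange(12).reshape(3, 4)
idx = np.array([[0], [2], [1]])  # one column index per row

# NumPy: modifies a_np in place (the call returns None).
np.put_along_axis(a_np, idx, 99, axis=1)

# JAX: broadcast row indices against idx, then do a functional update.
a = jnp.arange(12).reshape(3, 4)
rows = jnp.arange(a.shape[0])[:, None]  # shape (3, 1), broadcasts with idx
b = a.at[rows, idx].set(99)
```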
Sure, I'd review a `put_along_axis` PR
Sounds good - feel free to take a look!
You might get more traction asking about this in the torch project.
This looks like a garbage collection issue. In each iteration, once `inputs` goes out of scope, it doesn't immediately get deleted. Rather, its CPython reference count goes to zero, and...
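The comment above is cut off, so the following is not necessarily the fix it goes on to propose. One way to avoid waiting on Python to reclaim a large array is to release its device buffer explicitly with `jax.Array.delete()`. The loop body and shapes below are placeholders:

```python
import gc

import jax.numpy as jnp

results = []
for i in range(3):
    inputs = jnp.ones((256, 256)) * i  # placeholder for the real inputs
    results.append(float(inputs.sum()))
    # Release the device buffer now, rather than relying on Python to reclaim
    # `inputs` once it is rebound or goes out of scope.
    inputs.delete()

# After delete(), the Python object still exists but its buffer is freed.
freed = inputs.is_deleted()

# Optionally run a collection pass for any other unreachable objects.
gc.collect()
```

Any later use of a deleted array raises an error, so this only fits cases where nothing else still holds a reference to `inputs`.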
I think this is working as expected: when you run these lines:

```python
iters = 500_000
key, *data_keys = random.split(key, iters + 1)
```

you're creating a list `data_keys` with...
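The comment is truncated, but the structure it refers to can be checked directly. Star-unpacking the result of `random.split` produces a Python list holding `iters` separate key arrays. A possible alternative, sketched here as an assumption and not necessarily what the comment recommends, keeps the data keys as one stacked array. It splits differently, so the individual keys won't match the original ones:

```python
import jax
from jax import random

iters = 1_000  # smaller than the 500_000 in the report, to keep this quick
key = random.PRNGKey(0)

# Star-unpacking yields a Python list of `iters` separate key arrays.
key_a, *data_keys_list = random.split(key, iters + 1)

# Alternative: keep the data keys as a single array of shape (iters, 2).
key_b, subkey = random.split(key)
data_keys = random.split(subkey, iters)

# The stacked keys can be consumed in one batched call with vmap.
samples = jax.vmap(random.normal)(data_keys)
```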
(I just realized my comment is the same solution @mattjj offered above... sorry if there's something I'm missing)