Jake Vanderplas

430 comments by Jake Vanderplas

Thanks for the report - we definitely should have a docstring for that function. Fortunately, it's pretty straightforward: it basically calls `functools.reduce()` over `tree_leaves(tree)`: https://github.com/google/jax/blob/94aade035a5fdeb2d3ed6f1744fcf1fa16240b8c/jax/_src/tree_util.py#L248-L254

Would you be interested in...
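The snippet doesn't name the function. Assuming it is `jax.tree_util.tree_reduce`, the behavior described (a `functools.reduce()` over the flattened leaves) can be sketched as follows. `tree_reduce_sketch` and the `_NO_INITIALIZER` sentinel are hypothetical names used only for illustration, not JAX's actual source.

```python
import functools
import operator

import jax
from jax.tree_util import tree_leaves

# Sentinel so that `None` can still be passed as a real initializer (hypothetical helper).
_NO_INITIALIZER = object()


def tree_reduce_sketch(function, tree, initializer=_NO_INITIALIZER):
    """Hypothetical re-implementation: reduce `function` over the tree's leaves."""
    leaves = tree_leaves(tree)  # flatten the pytree into a list of leaves
    if initializer is _NO_INITIALIZER:
        return functools.reduce(function, leaves)
    return functools.reduce(function, leaves, initializer)


# Dict keys are flattened in sorted order, so the leaves are [1, 2, 3, 4].
tree = {"a": 1, "b": (2, 3), "c": [4]}
total = tree_reduce_sketch(operator.add, tree)  # 10
```

Comparing against `jax.tree_util.tree_reduce(operator.add, tree)` is an easy way to check that the sketch matches the library on a given tree.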

We haven't implemented this because the semantics of `np.put_along_axis` are to modify the array in place, and this is not possible in JAX because JAX arrays are immutable. I suspect you...
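To make the contrast concrete, here is a small sketch (the array values are arbitrary): NumPy's `put_along_axis` mutates its argument and returns `None`, while attempting in-place item assignment on a JAX array raises a `TypeError`.

```python
import numpy as np
import jax.numpy as jnp

a = np.array([[10, 30, 20], [60, 40, 50]])
idx = np.array([[1], [0]])

# NumPy writes into `a` in place; the function itself returns None.
returned = np.put_along_axis(a, idx, 99, axis=1)
# a is now [[10, 99, 20], [99, 40, 50]]

# JAX arrays are immutable, so in-place item assignment raises TypeError.
x = jnp.array([[10, 30, 20], [60, 40, 50]])
assignment_failed = False
try:
    x[0, 1] = 99
except TypeError:
    assignment_failed = True
```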

One idea may be to define `jax.numpy.put_along_axis`, but add an extra `inplace` keyword that defaults to `True`, such that the function errors with the default value. Users could set `inplace=False`...
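A rough sketch of what that proposal could look like, assuming the non-in-place path is built on `.at[]`. `put_along_axis_sketch` is a hypothetical name, and the indexing construction below is one reasonable way to do it, not necessarily how a real implementation would be written.

```python
import jax.numpy as jnp
import numpy as np


def put_along_axis_sketch(arr, indices, values, axis, inplace=True):
    """Hypothetical JAX version of np.put_along_axis with an `inplace` flag."""
    if inplace:
        # NumPy's in-place semantics can't be honored for immutable JAX arrays.
        raise ValueError(
            "JAX arrays are immutable; pass inplace=False to get an updated copy."
        )
    arr = jnp.asarray(arr)
    indices = jnp.asarray(indices)
    if indices.ndim != arr.ndim:
        raise ValueError("indices must have the same number of dimensions as arr")
    axis = axis % arr.ndim  # normalize negative axes
    # One index array per dimension: `indices` along `axis`, and a
    # broadcastable arange over every other dimension.
    index = []
    for dim in range(arr.ndim):
        if dim == axis:
            index.append(indices)
        else:
            shape = [1] * arr.ndim
            shape[dim] = arr.shape[dim]
            index.append(jnp.arange(arr.shape[dim]).reshape(shape))
    # Functional update: returns a new array and leaves `arr` unchanged.
    return arr.at[tuple(index)].set(values)


a = jnp.array([[10, 30, 20], [60, 40, 50]])
idx = jnp.array([[1], [0]])
b = put_along_axis_sketch(a, idx, 99, axis=1, inplace=False)
# b is [[10, 99, 20], [99, 40, 50]]; a is unchanged
```

Raising by default means code ported directly from NumPy fails loudly instead of silently discarding the update.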

No, but note that you should be able to use [`jnp.ndarray.at[]`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) to do anything that you might do with `put_along_axis`.
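A short illustration of the `.at[]` update methods (the values are arbitrary). Each call returns a new array, so the original is never modified:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Every .at[...] method returns a fresh array; `x` stays all zeros.
y = x.at[1].set(3.0)                   # [0., 3., 0., 0., 0.]
z = y.at[1:3].add(1.0)                 # [0., 4., 1., 0., 0.]
w = z.at[jnp.array([0, 4])].set(-1.0)  # [-1., 4., 1., 0., -1.]
v = w.at[2].multiply(10.0)             # [-1., 4., 10., 0., -1.]
```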

Sure, I'd review a `put_along_axis` PR.

Sounds good - feel free to take a look!

You might get more traction asking about this in the torch project.

This looks like a garbage collection issue. In each iteration, once `inputs` goes out of scope, it doesn't immediately get deleted. Rather, its CPython reference count goes to zero, and...
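The rest of this comment is cut off, so the recommended fix isn't shown here. One common way to avoid depending on when Python reclaims each iteration's array is to release it explicitly. The sketch below assumes a hypothetical loop in which `inputs` is a fresh array each iteration. It uses `jax.Array.delete()` to free the buffer, then `gc.collect()`. Calling the collector every iteration is slow, so treat this as a diagnostic rather than a recommended pattern.

```python
import gc

import jax.numpy as jnp


def run(num_iters=3, size=1_000):
    """Hypothetical loop that frees each iteration's input buffer explicitly."""
    total = 0.0
    for i in range(num_iters):
        inputs = jnp.full((size,), float(i))  # stand-in for real per-iteration data
        total += float(jnp.sum(inputs))
        inputs.delete()  # release the underlying buffer right away
        del inputs       # drop the Python reference as well
        gc.collect()     # force a collection pass (diagnostic only)
    return total
```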

I think this is working as expected: when you run these lines:

```python
iters = 500_000
key, *data_keys = random.split(key, iters + 1)
```

you're creating a list `data_keys` with...
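Star-unpacking the result of `random.split` iterates over the returned key array. As a result, `data_keys` is a Python list holding one separate small array per key. Below is a minimal sketch of the alternative, which keeps the keys as one stacked array and slices it. It uses a smaller `iters` than the original to stay quick, and the seed in `random.PRNGKey(0)` is an arbitrary assumption.

```python
import jax.numpy as jnp
from jax import random

iters = 100  # reduced from 500_000 for the demo

key = random.PRNGKey(0)  # arbitrary example seed

# Star-unpacking: a Python list of `iters` individual key arrays.
key_a, *data_keys_list = random.split(key, iters + 1)

# Alternative: one stacked array of keys, sliced without iterating.
keys = random.split(key, iters + 1)
key_b, data_keys = keys[0], keys[1:]
```

Both versions produce the same keys. The stacked array avoids creating `iters` separate Python-level array objects, and `data_keys[i]` still gives the i-th key.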

(I just realized my comment is the same solution @mattjj offered above... sorry if there's something I'm missing)