Jake Vanderplas

673 comments by Jake Vanderplas

JAX does not work with multiprocessing. The issue you're seeing is that one JAX process is pre-reserving the CUDA device memory, then another JAX process tries to do the same...
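A minimal sketch of relaxing that up-front reservation uses JAX's documented `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable. By default, JAX's GPU backend preallocates 75% of device memory when it initializes, which is what collides when several processes each initialize JAX on the same GPU. The helper name `disable_jax_gpu_preallocation` is hypothetical. This only changes *when* device memory is claimed; it does not make JAX safe to combine with fork-based multiprocessing.

```python
import os

def disable_jax_gpu_preallocation() -> None:
    """Hypothetical helper: opt this process out of JAX's up-front GPU reservation.

    Call it before `import jax` runs anywhere in the process, so the setting is
    in place when the GPU backend initializes.
    """
    # With preallocation off, JAX allocates device memory on demand instead of
    # reserving a large fixed fraction when the backend starts.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```

Each worker process would call this before its own first `import jax`.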

GPU allocations shouldn't happen on import, but on the first JAX operation (for example, creating an array). There are two possibilities, I think: (1) that behavior has changed in a way...

That said, it **is** expected that JAX will warn about multiprocessing after import alone, because the fork-registration hook is registered at import time. The way to address that would be...
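The general mechanism can be sketched with the standard library's `os.register_at_fork`: a hook registered when a module is imported fires on every later `os.fork()` in that process, which is why a warning can appear even if no JAX operation has run. The module state, hook, and message below are hypothetical stand-ins that illustrate the mechanism only; they are not JAX's actual code. (POSIX only, since `os.fork` is unavailable on Windows.)

```python
import os
import warnings

# Hypothetical module-level state, used here to observe when the hook fires.
fork_events = []

def _warn_before_fork():
    # Runs in the parent process immediately before os.fork() creates a child.
    fork_events.append(os.getpid())
    warnings.warn(
        "os.fork() was called after this multithreaded library was imported",
        RuntimeWarning,
    )

# Executed at import time: from now on, every fork in this process triggers
# the hook, whether or not any library function has been called.
os.register_at_fork(before=_warn_before_fork)

def fork_and_wait():
    """Fork a child that exits immediately; return its wait status."""
    pid = os.fork()
    if pid == 0:
        os._exit(0)  # child: exit right away without running cleanup handlers
    _, status = os.waitpid(pid, 0)
    return status
```

Calling `fork_and_wait()` inside `warnings.catch_warnings(record=True)` records the `RuntimeWarning` even though the "library" did nothing but get imported.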

To be clear, are these the semantics you have in mind?

```python
import jax
import jax.numpy as jnp

def stack_leaves(pytrees, axis):  # stack matching leaves of every pytree along `axis`
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis), *pytrees)
```
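A usage sketch of `stack_leaves`, assuming every element of `pytrees` has the same tree structure (the example trees are made up for illustration). Unpacking with `*pytrees` is what makes `jax.tree.map` hand the corresponding leaves of all trees to the lambda together; passing the list itself would visit each leaf on its own and "stack" it alone.

```python
import jax
import jax.numpy as jnp

def stack_leaves(pytrees, axis):
    # Each tree is a separate argument to tree.map, so the lambda receives one
    # leaf from every tree at the same position and stacks them.
    return jax.tree.map(lambda *xs: jnp.stack(xs, axis), *pytrees)

# Two pytrees with identical structure: a dict holding a vector and a scalar.
trees = [
    {"w": jnp.ones((3,)), "b": jnp.array(0.0)},
    {"w": jnp.zeros((3,)), "b": jnp.array(1.0)},
]
stacked = stack_leaves(trees, axis=0)
print(stacked["w"].shape)  # (2, 3)
print(stacked["b"].shape)  # (2,)
```

Using `axis=-1` instead puts the new dimension last, which works for leaves of any rank.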

For something like this, I'd probably lean toward recommending that users implement what they need by composing existing APIs, rather than providing a new API for something that can already be...
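As an illustration of that kind of composition, here is a sketch of the inverse operation built only from `jax.tree.leaves`, `jax.tree.map`, and `jnp.take`. The name `unstack_leaves` is hypothetical, not a JAX API, and the sketch assumes every leaf has the same size along `axis`.

```python
import jax
import jax.numpy as jnp

def unstack_leaves(stacked, axis=0):
    # Hypothetical helper: split one pytree of stacked arrays into a list of pytrees.
    # Assumes all leaves share the same length along `axis`.
    n = jax.tree.leaves(stacked)[0].shape[axis]
    # jnp.take with a scalar index drops `axis`, yielding the i-th slice of each leaf.
    return [jax.tree.map(lambda x: jnp.take(x, i, axis=axis), stacked) for i in range(n)]

stacked = {"w": jnp.arange(6.0).reshape(2, 3), "b": jnp.array([0.0, 1.0])}
parts = unstack_leaves(stacked)
print(len(parts), parts[0]["w"].shape)  # 2 (3,)
```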

A pytree cookbook would be an interesting idea! This idea also came up in #20594. @ayaka14732, is that something you'd be interested in thinking about?

Assigning @mattjj, who has thought a bit about how to support ragged operations more broadly in JAX.

Looks great! The last thing we need before getting this submitted is to squash the changes into a single commit.

Hi, JAX developer here. It looks like you're using a Colab TPU; as of this writing (October 2023), Colab provides only very old TPU hardware and is only compatible...

One disadvantage is that GitHub is a bit of an outlier in providing readable diffs for notebooks; in other systems (Critique, `git diff`, etc.) diffs are harder to read, and...