jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Results 1164 jax issues

Add functionality for "pure" callbacks. Also avoids using the C++ dispatch path when host callbacks are involved.
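A minimal sketch of a "pure" host callback, assuming the `jax.pure_callback` API available in current JAX releases (whether it matches this PR exactly is an assumption). The host function `host_sin` is hypothetical; the point is that JAX only needs the output shape and dtype to trace through the callback, and may skip or repeat it because it is declared pure.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_sin(x):
    # Runs on the host with plain NumPy. Because the callback is declared
    # "pure", JAX is free to skip, reorder, or repeat calls to it.
    return np.sin(x).astype(x.dtype)


@jax.jit
def f(x):
    # The ShapeDtypeStruct describes the result without running host_sin.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_spec, x)


x = jnp.linspace(0.0, 1.0, 4, dtype=jnp.float32)
y = f(x)
```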

[Pax] Support checkpoint policies in pipeline layer.
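Checkpoint policies control which intermediates a rematerialized function keeps from the forward pass. A minimal sketch with plain `jax.checkpoint` and a built-in policy from `jax.checkpoint_policies`; Pax's pipeline layer is not shown, and the `layer` function below is a hypothetical stand-in for one pipeline stage.

```python
import jax
import jax.numpy as jnp


def layer(w, x):
    # Hypothetical stand-in for one pipeline stage: dense layer + tanh.
    return jnp.tanh(x @ w)


# Keep only matmul outputs (without batch dims) from the forward pass;
# everything else, such as the tanh output, is recomputed on the backward pass.
policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
layer_remat = jax.checkpoint(layer, policy=policy)


def loss(w, x):
    return jnp.sum(layer_remat(w, x) ** 2)


def loss_ref(w, x):
    # Same loss without rematerialization, for comparison.
    return jnp.sum(layer(w, x) ** 2)


w = jnp.ones((3, 3)) * 0.1
x = jnp.arange(6.0).reshape(2, 3)
g = jax.grad(loss)(w, x)
g_ref = jax.grad(loss_ref)(w, x)
```

The policy changes memory use and recomputation, not the math, so gradients match the non-checkpointed version.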

`JaxTestCase` now sets this to strict by default, so these annotations are no longer necessary: https://github.com/google/jax/blob/af18235ea30799ed2fc50e557ebba869e2bdbd41/jax/_src/test_util.py#L706-L713

pull ready

Hi, I have a simple script below that compares runtimes for `pmap` vs `pjit`. I expected that the runtime for `pjit` with full data parallelism would be the same for...

bug
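A rough sketch of such a comparison, assuming a single host and a newer JAX in which `pjit` has been folded into `jax.jit` with `NamedSharding` inputs. The toy `step` function and sizes are hypothetical; the important details are warming up (compiling) before timing and calling `block_until_ready()`, since JAX dispatches asynchronously.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_dev = jax.local_device_count()
batch_per_dev, dim = 128, 256


def step(w, x):
    # Hypothetical per-example computation: matmul, tanh, reduction.
    return jnp.sum(jnp.tanh(x @ w), axis=-1)


w = jnp.ones((dim, dim), dtype=jnp.float32) / dim
x = jnp.ones((n_dev * batch_per_dev, dim), dtype=jnp.float32)

# pmap: the leading axis must equal the number of devices; w is broadcast.
pmapped = jax.pmap(step, in_axes=(None, 0))
x_pmap = x.reshape(n_dev, batch_per_dev, dim)

# jit over a 1-D "data" mesh: batch sharded across devices, w replicated.
mesh = Mesh(np.array(jax.local_devices()), ("data",))
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w_repl = jax.device_put(w, NamedSharding(mesh, P()))
jitted = jax.jit(step)


def bench(fn, *args, iters=10):
    fn(*args).block_until_ready()  # compile and warm up before timing
    t0 = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    out.block_until_ready()  # wait for asynchronous dispatch to finish
    return out, (time.perf_counter() - t0) / iters


out_pmap, t_pmap = bench(pmapped, w, x_pmap)
out_jit, t_jit = bench(jitted, w_repl, x_sharded)
print(f"pmap: {t_pmap * 1e3:.3f} ms, jit+sharding: {t_jit * 1e3:.3f} ms")
```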

Use the newly added bindings for `use_global_device_ids` in AllReduce

pull ready

Allow collectives in manually sharded computations ... at least when the manual sharding applies to the whole mesh, because that's all that XLA can support right now. This is especially...

Hi, I'm reporting this issue with a disclaimer that it's probably not reproducible (sorry!) but hoping it's of some worth for the awesome JAX developers: We have a machine with...

bug
needs info
NVIDIA GPU

The bug can be reproduced by the following code:

```
import jax.numpy as jnp
from jax.random import PRNGKey, normal
from jax.lax import dot
from jax.config import config
config.update("jax_enable_x64", True)
A...
```

bug
P1 (soon)
NVIDIA GPU
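The reproducer enables 64-bit mode through `from jax.config import config`, an import path that newer JAX releases have deprecated in favor of `jax.config.update`. A minimal sketch of the same setup, assuming the reported matrix is a random normal `A` multiplied with `lax.dot` (the actual shapes and operation are truncated in the report, so `A.T` and the 4x4 shape here are hypothetical). It only checks that float64 is actually in effect; it does not reproduce the GPU bug.

```python
import jax

# Enable 64-bit types before creating any arrays.
jax.config.update("jax_enable_x64", True)

import numpy as np
import jax.numpy as jnp
from jax.lax import dot
from jax.random import PRNGKey, normal

A = normal(PRNGKey(0), (4, 4), dtype=jnp.float64)
B = dot(A, A.T)  # hypothetical operation standing in for the truncated code

# Reference result computed on the host with NumPy in float64.
B_ref = np.asarray(A) @ np.asarray(A).T
```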

Fix a bug where `get_backend` in host callback lowering was provided with a wrong platform

Hi, I'm trying to use `call_tf` in combination with `jacrev` and hence `vmap`. The `primitive_batcher` for `call_tf_p` is not implemented. I tried to modify the source code of `jax/experimental/jax2tf/call_tf.py`...

enhancement
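Why `jacrev` needs a batching rule: reverse-mode Jacobians are computed by mapping the VJP over the standard basis of the output space with `vmap`, so every primitive inside the function must be batchable. A minimal sketch with a hypothetical pure-JAX function `f` (`call_tf` itself is omitted because it requires TensorFlow):

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical vector-valued function R^3 -> R^2.
    return jnp.array([jnp.sin(x[0]) * x[1], x[1] * x[2] ** 2])


x = jnp.array([0.5, 2.0, -1.0])

# One VJP per output row, batched with vmap over the identity basis.
y, vjp_fn = jax.vjp(f, x)
basis = jnp.eye(y.shape[0], dtype=y.dtype)
(jac_manual,) = jax.vmap(vjp_fn)(basis)  # vjp_fn returns one cotangent per input

jac = jax.jacrev(f)(x)
```

If a primitive inside `f` has no batching rule, the `vmap` step above is exactly where tracing fails.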