jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
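The tagline's "composable" means the transformations nest freely. Below is a minimal sketch that differentiates (`jax.grad`), vectorizes (`jax.vmap`) and compiles (`jax.jit`) a toy function to get per-example gradients. The loss function and data are hypothetical, chosen only to illustrate the composition:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical toy loss for a single example x.
    return jnp.sum(jnp.tanh(x @ w) ** 2)

# Differentiate with respect to w, map over a batch of x (w is shared), then JIT-compile.
grad_loss = jax.grad(loss)
per_example_grads = jax.jit(jax.vmap(grad_loss, in_axes=(None, 0)))

w = jnp.ones((3,))
xs = jnp.arange(12.0).reshape(4, 3)
g = per_example_grads(w, xs)  # one gradient row per example
print(g.shape)  # (4, 3)
```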
Add functionality for "pure" callbacks. Also avoids using the CPP dispatch path when host callbacks are involved.
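A minimal sketch of a pure callback, assuming the public `jax.pure_callback(callback, result_shape_dtypes, *args)` entry point found in current JAX releases (this PR's original API may have differed). The host function `host_sin` is a hypothetical example. The result shape and dtype must be declared up front because the callback runs outside the traced computation:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host with NumPy arrays; must be free of side effects.
    return np.sin(x)

@jax.jit
def f(x):
    # Declare the output shape/dtype so JAX can trace through the callback.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_spec, x)

print(f(jnp.arange(3.0)))
```

Marking the callback as "pure" tells JAX that the call may be elided or duplicated, which is why it must not have side effects.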
[Pax] Support checkpoint policies in pipeline layer.
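For context on what a "checkpoint policy" is, here is a sketch using core JAX's `jax.checkpoint` with a policy from `jax.checkpoint_policies`. It does not show Pax's pipeline-layer API, and the `layer` function is hypothetical. `dots_saveable` keeps matmul outputs as residuals and recomputes the cheaper elementwise ops during the backward pass:

```python
import jax
import jax.numpy as jnp

def layer(w, x):
    # Hypothetical layer: a matmul followed by a nonlinearity.
    return jnp.tanh(w @ x)

# Save only dot/matmul results for the backward pass; recompute tanh.
layer_remat = jax.checkpoint(layer, policy=jax.checkpoint_policies.dots_saveable)

def loss(w, x):
    return jnp.sum(layer_remat(w, x))

w = jnp.eye(3) * 0.5
x = jnp.arange(3.0)
g = jax.grad(loss)(w, x)
print(g)
```

The policy only changes what is stored versus recomputed, so the gradient matches the non-checkpointed version.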
`JaxTestCase` now sets this to strict by default, so these annotations are no longer necessary: https://github.com/google/jax/blob/af18235ea30799ed2fc50e557ebba869e2bdbd41/jax/_src/test_util.py#L706-L713
Hi, I have a simple script below that compares runtimes for `pmap` vs `pjit`. I expected that the runtime for `pjit` with full data parallelism would be the same for...
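The original script is truncated, so the following is only a hypothetical sketch of one way to time the two approaches fairly. It assumes that in recent JAX releases `pjit`-style data parallelism is expressed with `jax.jit` on an input sharded by `NamedSharding`. Compilation is excluded from the timing, and `block_until_ready` accounts for asynchronous dispatch. The `step` function, array sizes and `bench` helper are illustrative:

```python
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def step(a):
    # Hypothetical workload: matmul plus nonlinearity.
    return jnp.tanh(a @ a)

n = jax.device_count()
x = jax.random.normal(jax.random.PRNGKey(0), (n, 256, 256))

# pmap: each device gets one (256, 256) slice; leading axis must equal n.
pmapped = jax.pmap(step)

# jit with a sharded input: leading axis split over a 1-D "data" mesh.
mesh = Mesh(np.array(jax.devices()), ("data",))
x_sharded = jax.device_put(x, NamedSharding(mesh, P("data")))
jitted = jax.jit(step)  # batched matmul over the leading axis

def bench(fn, arg, iters=10):
    fn(arg).block_until_ready()  # compile and warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(arg)
    out.block_until_ready()  # wait for queued asynchronous work to finish
    return (time.perf_counter() - start) / iters

print(f"pmap:          {bench(pmapped, x):.6f} s/iter")
print(f"jit (sharded): {bench(jitted, x_sharded):.6f} s/iter")
```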
Use the newly added bindings for `use_global_device_ids` in AllReduce.
Allow collectives in manually sharded computations, at least when the manual sharding applies to the whole mesh, because that's all XLA can support right now. This is especially...
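As a rough illustration of a collective inside a manually sharded computation whose manual sharding covers the whole mesh, the sketch below uses `shard_map`. This is a newer manual-sharding API than the one this change targeted, so that substitution is an assumption. The import location differs between JAX versions, hence the fallback. The mesh axis name `"x"` and the function `block_sum` are hypothetical:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# A 1-D mesh over every device: the manual sharding spans the whole mesh.
mesh = Mesh(np.array(jax.devices()), ("x",))

def block_sum(block):
    # Local reduction on this device's shard, then an all-reduce over "x".
    return jax.lax.psum(jnp.sum(block), "x")

total_fn = jax.jit(shard_map(block_sum, mesh=mesh, in_specs=P("x"), out_specs=P()))

x = jnp.arange(8.0 * jax.device_count())
total = total_fn(x)
print(total)
```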
Hi, I'm reporting this issue with the disclaimer that it's probably not reproducible (sorry!), but I hope it's of some use to the awesome JAX developers: We have a machine with...
The bug can be reproduced by the following code:

```python
import jax.numpy as jnp
from jax.random import PRNGKey, normal
from jax.lax import dot
from jax.config import config
config.update("jax_enable_x64", True)
A...
```
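The report's repro imports `config` from `jax.config`. In recent JAX releases that import path is deprecated, and the flag is set through `jax.config.update` instead. The sketch below shows only the setup portion under that assumption. The report's actual computation is truncated, so the arrays `A` and `b` here are hypothetical:

```python
import jax

# Enable 64-bit types at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.random import PRNGKey, normal, split
from jax.lax import dot

key_a, key_b = split(PRNGKey(0))
A = normal(key_a, (3, 3))  # float64 once x64 is enabled
b = normal(key_b, (3,))
y = dot(A, b)  # matrix-vector product
print(A.dtype, y.shape)
```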
Fix a bug where `get_backend` in host callback lowering was passed the wrong platform.
Hi, I'm trying to use `call_tf` in combination with `jacrev`, which uses `vmap` internally. The batching rule (the `primitive_batchers` entry) for `call_tf_p` is not implemented. I tried to modify the source code of `jax/experimental/jax2tf/call_tf.py`...
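To see why `jacrev` needs a batching rule, note that a reverse-mode Jacobian can be built by mapping `vmap` over a VJP with the standard basis vectors as cotangents. This is why any primitive without a batching rule fails under `jacrev`. The sketch below uses a plain JAX function `f` (hypothetical, with no `call_tf` involved, since that needs TensorFlow) and checks the manual construction against `jax.jacrev`:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical elementwise function standing in for the real computation.
    return jnp.sin(x) * x

x = jnp.arange(1.0, 4.0)
J = jax.jacrev(f)(x)

# Manual reverse-mode Jacobian: one VJP per output basis vector, batched with vmap.
y, vjp_fn = jax.vjp(f, x)
(J_manual,) = jax.vmap(vjp_fn)(jnp.eye(y.shape[0]))
print(J.shape)
```

Row `i` of `J_manual` is `e_i^T J`, so the two results agree.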