
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Results: 1164 jax issues

[XLA:Python] Add support for executing portable PJRT executables from Python. Add an optional `device` argument to execute that allows the user to trigger the `ExecutePortable` path. Plumb the `compile_portable_executable` compilation...

### Description

Pooling operations sometimes fail to convert. It looks like a `None` dimension is sometimes slipping through the cracks. The bug depends on the stride value... I'm using 'framed'...

bug

### Description

The following code example works perfectly for ndim = 10, but JAX returns NaN for ndim = 100. It starts failing at ndim = 15. I am working in double precision by...

bug
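Since the reporter's code is truncated, this sketch does not reproduce their failure. It only shows two real JAX config flags that help when chasing a NaN like this: `jax_enable_x64` turns on 64-bit types, and `jax_debug_nans` raises `FloatingPointError` at the first operation that produces a NaN. The function `first_nan_op` is a hypothetical stand-in for the failing computation, and nothing here claims this is how the reporter enabled double precision.

```python
import jax

# Enable 64-bit floats/ints. Set this at startup, before creating any arrays.
jax.config.update("jax_enable_x64", True)
# Make JAX raise FloatingPointError as soon as a computation produces a NaN.
jax.config.update("jax_debug_nans", True)

import jax.numpy as jnp


def first_nan_op(x):
    # Hypothetical stand-in for the failing computation: log of a negative is NaN.
    return jnp.log(x)


print(jnp.ones(3).dtype)  # float64 once x64 is enabled

try:
    first_nan_op(jnp.array(-1.0))
    raised = False
except FloatingPointError as err:
    raised = True
    print("caught:", err)
```

With `jax_debug_nans` on, the traceback points at the operation that first produced the NaN, which narrows down where a computation starts failing as a parameter like `ndim` grows.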

Please:

- [X] Check for duplicate requests.
- [X] Describe your goal, and if possible provide a code snippet with a motivating example.

Hi all, Whilst working on a recent...

enhancement
contributions welcome

Take the pjit XLA lowering path for `Arrays`. In the test, `astype` happens in a sharded fashion without the round trip to host.
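The PR description says `astype` on a sharded `Array` now runs in a sharded fashion without a round trip to host. A minimal sketch of that user-facing operation, assuming a recent JAX that ships the `jax.sharding` module. The mesh axis name `"x"` is arbitrary, and the code itself does not verify that no host transfer happens.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over every available device (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

# Shard a float32 array along its only axis; the length divides evenly by device count.
n = 4 * len(devices)
x = jax.device_put(jnp.arange(n, dtype=jnp.float32), sharding)

# Cast the sharded array; per the PR, this runs sharded on device.
y = x.astype(jnp.int32)

print(y.dtype, y.sharding)
```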

Putting this here and tagging myself @ericmjl so that I can remember this exists. To get `jax` into the hands of data scientists and machine learning researchers, `conda` installation would...

enhancement
build
contributions welcome
P2 (eventual)
NVIDIA GPU

```python3
import jax
import jax.numpy as jnp

x = jnp.arange(10)
x = jax.device_put(x)
print(x[[13]])  # index 13 is out of bounds for a length-10 array
```

This prints `[3]`, but it should actually throw an _out of bounds_ error like the...

documentation
P2 (eventual)
NVIDIA GPU
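JAX does not raise on out-of-bounds indices by default. A minimal sketch of two real ways to make the behavior explicit: the `mode`/`fill_value` arguments of `.at[...].get()`, and runtime index checks via `jax.experimental.checkify`. The helper `take` is hypothetical, and the exact wording of the error message may differ between JAX versions.

```python
import jax.numpy as jnp
from jax.experimental import checkify

x = jnp.arange(10)

# Option 1: request explicit out-of-bounds semantics for this lookup.
filled = x.at[jnp.array([13])].get(mode="fill", fill_value=-1)
print(filled)  # out-of-bounds positions receive fill_value


# Option 2: functionalized runtime checks via checkify.
def take(arr, i):
    return arr[i]


checked_take = checkify.checkify(take, errors=checkify.index_checks)
err, out = checked_take(x, 13)
print(err.get())  # a message describing the out-of-bounds index, or None
# err.throw() would raise the error instead of returning the message.
```

`checkify` keeps the check usable under `jit`, because the error is returned as a value rather than raised inside the traced function.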

Running the following example with `--jit` leaks memory (duplicates x?); GPU memory utilization increases by 1 GiB per iteration. Without `--jit` it works fine. ```python3 #!/usr/bin/env python3 import argparse import...

bug
P0 (urgent)
NVIDIA GPU
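The reporter's script is truncated, so this sketch does not reproduce it. It shows one documented way to stop a jitted update from keeping both its input and output buffers alive: buffer donation via `donate_argnums`. Whether donation resolves this particular report is an assumption. `step` is a hypothetical update function, and `jax.live_arrays()` is used only to inspect what JAX still holds.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, donate_argnums=0)
def step(x):
    # Hypothetical update; with donation, XLA may reuse x's buffer for the output.
    return x * 2.0 + 1.0


x = jnp.zeros((1024, 1024), dtype=jnp.float32)
for _ in range(5):
    # Rebind x to the output so the donated input is never used again.
    x = step(x)

# Arrays JAX still holds; useful for spotting buffers that pile up per iteration.
print(len(jax.live_arrays()))
print(x[0, 0])
```

On backends that cannot use a donated buffer, JAX emits a warning and falls back to allocating a new one, so the loop still runs.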

I'm looking to implement custom GPU ops, similar to how TensorFlow allows defining custom JVPs. Is there a similar tutorial or guide on how feasible this is with JAX?

P2 (eventual)
NVIDIA GPU
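For the custom-JVP half of this question, JAX's real API is `jax.custom_jvp`. The sketch below follows the `log1pexp` pattern from JAX's custom derivatives tutorial: the custom rule keeps the gradient finite where naive differentiation of the forward pass breaks down. Writing actual custom GPU kernels is a separate topic (JAX's foreign function interface and Pallas) and is not covered here.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # Forward pass: exp(x) overflows to inf for large x in float32.
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = sigmoid(x), which stays finite for large x.
    ans_dot = jax.nn.sigmoid(x) * x_dot
    return ans, ans_dot


grad_fn = jax.grad(log1pexp)
print(grad_fn(0.0))
print(grad_fn(100.0))
```

The JVP rule must be linear in the tangent `x_dot`, which is what lets `jax.grad` derive reverse mode from it automatically.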