jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
[XLA:Python] Add support for executing portable PJRT executables from Python. Add an optional `device` argument to execute that allows the user to trigger the `ExecutePortable` path. Plumb the `compile_portable_executable` compilation...
### Description Pooling operations sometimes fail to convert. It looks like a None dimension is sometimes slipping through the cracks. The bug depends on the stride value... I'm using 'framed'...
### Description The following code example works perfectly for ndim = 10, but JAX returns nan for ndim=100. It starts failing at ndim=15. I am working in double precision by...
Please: - [X] Check for duplicate requests. - [X] Describe your goal, and if possible provide a code snippet with a motivating example. Hi all, Whilst working on a recent...
[mhlo] Move CreateTokenOp from HLO to CHLO.
Take the pjit XLA lowering path for `Arrays`. In the test, `astype` happens in a sharded fashion without the round trip to host.
Putting this here and tagging myself @ericmjl so that I can remember this exists. To get `jax` into the hands of data scientists and machine learning researchers, `conda` installation would...
```python3
import jax
import jax.numpy as np

x = np.arange(10)
x = jax.device_put(x)
print(x[[13]])
```
This prints `[3]`, but it should actually throw an _out of bounds_ error like the...
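For context on the report above, here is a minimal sketch of how out-of-bounds reads can be made explicit. It assumes a recent JAX release in which the `.at[...].get()` helper accepts `mode` and `fill_value`; the names `filled` and `clipped` are illustrative and not part of the original report. It does not raise an error, since index values are generally not checked inside compiled code; it only makes the chosen behaviour visible at the call site.

```python
import jax.numpy as jnp

x = jnp.arange(10)
idx = jnp.array([13])  # deliberately past the end of x

# mode="fill" returns fill_value for any out-of-bounds index
# (assumption: a JAX version where .at[].get supports these arguments).
filled = x.at[idx].get(mode="fill", fill_value=-1)

# mode="clip" clamps the index into the valid range, so 13 reads x[9].
clipped = x.at[idx].get(mode="clip")

print(filled, clipped)
```

Using a sentinel such as `-1` may make a bad index easier to spot downstream than a silently clamped or wrapped value.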
Running the following example with `--jit` leaks memory (duplicates x?); GPU memory utilization increases by 1 GiB per iteration. Without `--jit` it works fine. ```python3 #!/usr/bin/env python3 import argparse import...
I'm looking to implement custom GPU ops similar to how tensorflow allows for defining custom jvps. Is there a similar tutorial/guide on how feasible this will be with jax?
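For the differentiation-rule half of this question, a minimal sketch using `jax.custom_jvp` may be a useful starting point. The function `log1pexp` and its derivative rule are illustrative choices, not taken from the original post, and the sketch does not cover writing or registering the GPU kernel itself.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # Forward computation: log(1 + exp(x)).
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    # primals and tangents are tuples with one entry per positional argument.
    (x,) = primals
    (x_dot,) = tangents
    y = log1pexp(x)
    # Hand-written derivative: d/dx log(1 + exp(x)) = 1 / (1 + exp(-x)).
    return y, x_dot / (1.0 + jnp.exp(-x))


# jax.grad uses the custom rule defined above.
grad_at_zero = jax.grad(log1pexp)(0.0)
print(grad_at_zero)
```

Because the rule is written as a JVP, JAX can derive reverse-mode gradients from it, which is why `jax.grad` works here without a separate VJP definition.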