
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Results 1164 jax issues

This problem has been bothering me for months... I really want to know how to fix it. When implementing NeRF with JAX, I want to divide the data into batches,...

P1 (soon)
NVIDIA GPU
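The question above is truncated, but one common way to process NeRF rays in fixed-size batches is to pad the ray array, reshape it into batches, and run the renderer over them sequentially with `jax.lax.map`, which is built on `lax.scan` and so keeps one batch's intermediates live at a time instead of the whole dataset's. This is a minimal sketch, not the issue author's code: `render_rays` is a hypothetical stand-in for a real NeRF renderer, and the `(N, 6)` ray layout (origin plus direction) is an assumption.

```python
import jax
import jax.numpy as jnp

def render_rays(rays):
    # Hypothetical stand-in for a NeRF renderer: (N, 6) rays -> (N, 3) colors.
    return jnp.tanh(rays[:, :3] + rays[:, 3:])

def render_in_batches(rays, batch_size):
    n = rays.shape[0]
    # Pad so the number of rays divides evenly into batches.
    pad = (-n) % batch_size
    padded = jnp.pad(rays, ((0, pad), (0, 0)))
    batches = padded.reshape(-1, batch_size, rays.shape[1])
    # lax.map applies render_rays to one batch at a time (sequential scan).
    out = jax.lax.map(render_rays, batches)
    # Flatten back to (N_padded, 3) and drop the padded rows.
    return out.reshape(-1, out.shape[-1])[:n]
```

If the function is wrapped in `jax.jit`, `batch_size` has to be marked static (for example with `static_argnums`) because it determines array shapes.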

Hello, I am trying to train a Reformer model using Trax and JAX. The training fails on Google Colab because of memory limitations. When I run it on a Google Cloud server...

I could not find an answer to this in the documentation: I would like to use one specific device, out of several available ones, throughout my program, and that device is not `jax.devices()[0]`. Is...

enhancement
question
P1 (soon)
NVIDIA GPU
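In newer JAX releases there are two relevant mechanisms: `jax.device_put(x, device)` commits an array to a device, and jitted computations on committed inputs run on that device; the `jax.default_device(device)` context manager sets where newly created arrays and computations go. `jax.default_device` is not available in very old releases such as the jaxlib 0.1.x era of these issues. A minimal sketch, assuming the goal is to use the last listed device (a hypothetical choice; on a single-device machine it is simply the only device):

```python
import jax
import jax.numpy as jnp

devices = jax.devices()
# Hypothetical choice: pick the last device instead of devices[0].
target = devices[-1]

# Commit an array to `target`; jitted work on it follows it there.
x = jax.device_put(jnp.arange(4.0), target)
y = jax.jit(lambda a: a * 2)(x)

# Arrays created inside this block default to `target`.
with jax.default_device(target):
    z = jnp.ones(3)
```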

Consider the following code taken from the JAX repository: https://github.com/google/jax/blob/master/examples/differentially_private_sgd.py I'd like to understand how to profile the memory used by the neural network during an epoch. I looked into:...

question
P2 (eventual)
NVIDIA GPU
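JAX includes a device memory profiler, `jax.profiler.save_device_memory_profile`, which writes a pprof-format snapshot of the device buffers that are live when it is called. A minimal sketch, assuming a toy jitted `step` function stands in for one training step of the DP-SGD example (it is hypothetical, not code from that script) and using `memory.prof` as an arbitrary output file name:

```python
import jax
import jax.numpy as jnp
import jax.profiler

@jax.jit
def step(w, x):
    # Hypothetical stand-in for one training step: a simple gradient-like update.
    return w - 0.01 * (x.T @ jnp.tanh(x @ w))

w = jnp.ones((256, 256))
x = jnp.ones((64, 256))
for _ in range(3):
    w = step(w, x)

# Wait for async dispatch to finish so the snapshot reflects real buffers.
w.block_until_ready()
# Write a pprof snapshot of the device buffers live at this point.
jax.profiler.save_device_memory_profile("memory.prof")
```

Because the snapshot only captures buffers alive at the moment of the call, calling it at several points inside the epoch loop (with different file names) gives a rough picture of how memory evolves. The file can be inspected with the `pprof` tool, for example `pprof --web memory.prof`.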

Running `python3 build.py --enable_cuda` fails with `No library found under: /usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudart.so.11.1`, but what I have there is `libcudart.so.11.0` and `libcudart.so.11.1.74`. My CUDA version is 11.1, and it passed the test.

P0 (urgent)
NVIDIA GPU

Hello, I am conducting an experiment using a server with 4 GPUs. I just ran the script `spmd_mnist_classifier_fromscratch.py` in the examples folder. Then I used Nvidia system to profile the...

enhancement
XLA
P3 (no schedule)
NVIDIA GPU

Is there an efficient GPU (i.e. XLA one-op) implementation of floating-point division available in JAX where division by `0` returns `0` instead of ~raising an error~ `inf`? In my...

question
P2 (eventual)
NVIDIA GPU
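A common pattern, though not a single XLA op, is the "double `where`": first replace zero denominators with a harmless value, then select `0` wherever the original denominator was zero. The first `where` matters for differentiation, because a plain `jnp.where(y == 0, 0.0, x / y)` still computes `x / 0` on the masked branch and can produce `nan` gradients. Under `jax.jit`, XLA usually fuses these elementwise ops into one kernel. A minimal sketch; `safe_divide` is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp

def safe_divide(x, y):
    # Swap zero denominators for 1 so x / y_safe never produces inf or nan,
    # which also keeps gradients finite on the masked entries.
    y_safe = jnp.where(y == 0, 1.0, y)
    # Return 0 wherever the original denominator was 0.
    return jnp.where(y == 0, 0.0, x / y_safe)
```

For example, `safe_divide(jnp.array([1.0, 2.0]), jnp.array([0.0, 4.0]))` gives `[0.0, 0.5]`, and `jax.grad(lambda y: safe_divide(1.0, y))(0.0)` is `0.0` rather than `nan`.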

Working on a local RTX 2060 Super GPU with CUDA 11.1, I got this error. jax has been installed successfully with the following ``` pip install --upgrade jax jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html ```...

P0 (urgent)
NVIDIA GPU

I'm working on a training script (https://github.com/rwightman/efficientnet-jax/blob/master/tf_linen_train.py) based on the Flax Linen ImageNet example (https://github.com/google/flax/blob/master/linen_examples/imagenet/imagenet_lib.py). It was working great on a system with 2 x Titan RTX. The same setup on 2...

bug
P0 (urgent)
NVIDIA GPU

The following MWE trains a simple neural ODE model with gradient descent to match a 2-D dynamical system (Van der Pol oscillator) with sampled data along a single trajectory. Each...

P1 (soon)
NVIDIA GPU
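The MWE itself is truncated above, so the following is not it; it is a minimal sketch of the same kind of setup, assuming `jax.experimental.ode.odeint` as the solver, a small hypothetical MLP (`init_params`, `mlp_dynamics`) as the learned vector field, and a Van der Pol trajectory with `mu = 1` generated in-script as the "sampled" data. `odeint` differentiates through the solve with an adjoint method, so `jax.value_and_grad` works on the loss directly.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def van_der_pol(y, t, mu=1.0):
    # Ground-truth 2-D dynamics used to generate the training trajectory.
    x, v = y
    return jnp.array([v, mu * (1 - x**2) * v - x])

t = jnp.linspace(0.0, 5.0, 20)
y0 = jnp.array([2.0, 0.0])
data = odeint(van_der_pol, y0, t)  # (20, 2) samples along one trajectory

def init_params(key, hidden=32):
    # Hypothetical one-hidden-layer MLP parameters.
    k1, k2 = jax.random.split(key)
    return {
        "W1": 0.1 * jax.random.normal(k1, (2, hidden)),
        "b1": jnp.zeros(hidden),
        "W2": 0.1 * jax.random.normal(k2, (hidden, 2)),
        "b2": jnp.zeros(2),
    }

def mlp_dynamics(y, t, params):
    # Learned vector field dy/dt = MLP(y); t is unused (autonomous system).
    h = jnp.tanh(y @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def loss(params):
    # Integrate the learned dynamics and compare with the sampled data.
    pred = odeint(mlp_dynamics, y0, t, params)
    return jnp.mean((pred - data) ** 2)

LEARNING_RATE = 1e-2

@jax.jit
def sgd_step(params):
    # One step of plain gradient descent on the trajectory-matching loss.
    value, grads = jax.value_and_grad(loss)(params)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, value
```

A training loop then just repeats `params, value = sgd_step(params)` starting from `params = init_params(jax.random.PRNGKey(0))`.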