jax
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
This problem has been bothering me for months... I really want to know how to fix it. When implementing NeRF with JAX, I want to divide the data into batches,...
TPU deadlock
Hello, I am trying to train a Reformer model using Trax and JAX. The training fails on Google Colab because of memory limitations. When I run it on a Google Cloud server...
I could not find an answer to this in the documentation: I would like to use one specific device out of several available devices throughout my program, and that device is not `jax.devices()[0]`. Is...
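One way to approach the device question above is sketched here. It assumes a recent JAX release in which `jax.default_device` is available as a context manager; choosing the last entry of `jax.devices()` as `target` is purely illustrative and not taken from the original question.

```python
import jax
import jax.numpy as jnp

# List every device JAX can see (CPU, GPU, or TPU cores).
devices = jax.devices()

# Illustrative choice: the last device. On a single-device machine this
# is the same as jax.devices()[0].
target = devices[-1]

# Option 1: commit an array to the chosen device explicitly.
# Computations on committed arrays run on the device that holds them.
x = jax.device_put(jnp.arange(4.0), target)
y = jnp.sin(x)

# Option 2: make the chosen device the default for a block of code
# (assumes a JAX version that provides jax.default_device).
with jax.default_device(target):
    z = jnp.ones(3) * 2.0
```

Wrapping the program's entry point in the `with` block is one way to apply the choice throughout a script without passing a device to every call.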
Consider the following code taken from the JAX repository: https://github.com/google/jax/blob/master/examples/differentially_private_sgd.py I'd like to understand how to profile the memory taken by the neural network in an epoch. I looked into:...
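As a rough starting point for the memory question above, JAX can write a snapshot of live device memory with `jax.profiler.save_device_memory_profile`. The sketch below is not tied to the differentially private SGD example: the `params` array and the output path are hypothetical stand-ins, and in a real training loop the snapshot would be taken at the point of interest, for example at the end of an epoch.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical stand-in for model state; not the network from the example.
params = jnp.ones((1024, 1024))
activations = jnp.dot(params, params)

# Wait for the computation so its buffers exist when the snapshot is taken.
activations.block_until_ready()

# Hypothetical output location chosen for this sketch.
profile_path = os.path.join(tempfile.mkdtemp(), "memory.prof")

# Write a snapshot of the buffers currently alive on the device.
jax.profiler.save_device_memory_profile(profile_path)
```

The resulting file is in pprof format, so it can be inspected with the `pprof` tool to see which call sites own the live buffers.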
Running `python3 build.py --enable_cuda` reports "No library found under: /usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudart.so.11.1", but what I have there is libcudart.so.11.0 and libcudart.so.11.1.74. My CUDA version is 11.1 and it passed the test.
Hello, I am conducting an experiment using a server with 4 GPUs. I just ran the script `spmd_mnist_classifier_fromscratch.py` under the example folder. Then I used Nvidia system to profile the...
Is there an efficient GPU (i.e., XLA one-op) implementation of floating-point division available in JAX where division by `0` returns `0` instead of ~raising an error~ `inf`? In my...
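A common way to get this behaviour is to mask the denominator with `jnp.where` before dividing and mask the result again afterwards. The sketch below is not a single XLA primitive, and `safe_divide` is a hypothetical helper name; under `jax.jit` the elementwise operations are usually fused by XLA, though that is not guaranteed here. Masking the denominator first also keeps gradients finite at zero.

```python
import jax
import jax.numpy as jnp


def safe_divide(x, y):
    """Elementwise x / y that returns 0 wherever y == 0 (hypothetical helper)."""
    nonzero = y != 0
    # Replace zeros in the denominator so the division never produces inf/nan.
    safe_y = jnp.where(nonzero, y, 1.0)
    # Keep the quotient where y was nonzero and return 0 elsewhere.
    return jnp.where(nonzero, x / safe_y, 0.0)


# jit-compile so XLA can optimize the elementwise operations together.
safe_divide_jit = jax.jit(safe_divide)

numerator = jnp.array([1.0, 2.0, 3.0])
denominator = jnp.array([2.0, 0.0, 4.0])
result = safe_divide_jit(numerator, denominator)

# The gradient with respect to the denominator stays finite at zero.
grad_at_zero = jax.grad(lambda y: safe_divide(1.0, y))(0.0)
```

Applying only the outer `jnp.where` to `x / y` would give the right forward values but can produce `nan` gradients, which is why the denominator is masked as well.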
Working on a local RTX 2060 Super GPU with CUDA 11.1, I got this error. JAX was installed successfully with the following: ``` pip install --upgrade jax jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html ```...
I'm working on a training script (https://github.com/rwightman/efficientnet-jax/blob/master/tf_linen_train.py) based on the Flax Linen ImageNet example (https://github.com/google/flax/blob/master/linen_examples/imagenet/imagenet_lib.py). It was working great on a system with 2 x Titan RTX. The same setup on 2...
The following MWE trains a simple neural ODE model with gradient descent to match a 2-D dynamical system (Van der Pol oscillator) with sampled data along a single trajectory. Each...