Performance Issue Report: JAX Slower Than Autograd on GPU and CPU Setups
Description
Introduction:
This report outlines a performance issue observed with JAX on both GPU and CPU hardware setups. Its purpose is to provide detailed feedback to the JAX development team to aid in identifying potential areas for optimization.
Observed Performance Issues:
- GPU Performance:
  - JAX is significantly slower than expected compared with Autograd on identical tasks, running at least 5x slower on NVIDIA GPUs.
- CPU Performance:
  - Similar underperformance is observed on Intel Core i7 CPUs, where JAX operations are markedly slower than the same operations with Autograd.
Steps to Reproduce:
- Set up the environment with specified hardware and software versions.
- Run benchmark tests, including matrix operations and gradient-based optimization (ADAM).
- Compare the execution times of JAX and Autograd (a minimal sketch of such a comparison follows this list).
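For concreteness, here is a minimal sketch of the kind of comparison these steps describe, timing repeated gradient evaluations in each library. The harness and the placeholder objective are illustrative assumptions, not the exact benchmark script that was run:

import time
import autograd
import autograd.numpy as anp
import jax
import jax.numpy as jnp

def f_anp(x):  # placeholder objective (Autograd version); the real one is given below
    return 0.5 * (x[0]**2 + (x[1] - 0.5)**2)

def f_jnp(x):  # the same placeholder objective for the JAX side
    return 0.5 * (x[0]**2 + (x[1] - 0.5)**2)

grad_anp = autograd.grad(f_anp)         # Autograd gradient
grad_jnp = jax.jit(jax.grad(f_jnp))     # JAX gradient, jitted

x0_anp = anp.array([-10.0, -80.0])
x0_jnp = jnp.array([-10.0, -80.0])
_ = grad_jnp(x0_jnp).block_until_ready()   # warm-up call so compilation is excluded

for name, g, x0 in [("autograd", grad_anp, x0_anp), ("jax", grad_jnp, x0_jnp)]:
    t0 = time.perf_counter()
    for _ in range(1000):
        out = g(x0)
    if name == "jax":
        out.block_until_ready()            # account for JAX's asynchronous dispatch
    print(name, time.perf_counter() - t0)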
Expected Behavior:
JAX should exhibit comparable or better performance than Autograd given its design for high-performance machine learning tasks, especially on platforms supporting GPU acceleration.
Actual Behavior:
JAX underperforms significantly compared to Autograd across all tested hardware setups.
Attachments:
- Benchmarking Scripts:
ADAM (beta1: 0.95, beta2: 0.99, epsilon: 0.001), plus BFGS, Newton-CG, and CG (standard scipy.optimize.minimize configuration), on the following synthetic function, in Autograd and JAX versions (a sketch of how the scipy methods are invoked appears after the two definitions):
import autograd.numpy as anp

def f(x):
    term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley
    # Nested valley 1 (deep, narrow)
    term2_x = -4 * anp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2_y = -8 * anp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2 = term2_x + term2_y
    # Nested valley 2 (wide, shallow)
    term3_x = 2 * anp.sin(5 * anp.pi * (x[0] - 1.25)) * anp.sin(5 * anp.pi * (x[1] - 1.75))
    term3_y = 3 * anp.sin(7 * anp.pi * (x[0] - 1.25)) * anp.sin(7 * anp.pi * (x[1] - 1.75))
    term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley
    term4 = 3 * anp.sin(3 * anp.pi * x[0]) * anp.sin(3 * anp.pi * x[1])  # Oscillating term
    term5 = -5 * anp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum
    term6 = -anp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum
    term7 = -2 * anp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum
    return term1 + term2 + term3 + term4 + term5 + term6 + term7
and
import jax.numpy as jnp

def f(x):
    term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley
    # Nested valley 1 (deep, narrow)
    term2_x = -4 * jnp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2_y = -8 * jnp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2 = term2_x + term2_y
    # Nested valley 2 (wide, shallow)
    term3_x = 2 * jnp.sin(5 * jnp.pi * (x[0] - 1.25)) * jnp.sin(5 * jnp.pi * (x[1] - 1.75))
    term3_y = 3 * jnp.sin(7 * jnp.pi * (x[0] - 1.25)) * jnp.sin(7 * jnp.pi * (x[1] - 1.75))
    term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley
    term4 = 3 * jnp.sin(3 * jnp.pi * x[0]) * jnp.sin(3 * jnp.pi * x[1])  # Oscillating term
    term5 = -5 * jnp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum
    term6 = -jnp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum
    term7 = -2 * jnp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum
    return term1 + term2 + term3 + term4 + term5 + term6 + term7
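For reference, the scipy.optimize.minimize configurations named above would presumably be driven roughly as follows. This is a sketch rather than the script that was actually used; the Autograd definition above is renamed f_anp and the JAX definition f_jax here to disambiguate the two:

import autograd
import jax
import numpy as np
from scipy.optimize import minimize

x0 = np.array([-10.0, -80.0])  # example starting point (the one quoted later in this thread)

# Autograd: pass autograd's gradient as the Jacobian.
res_bfgs = minimize(f_anp, x0, jac=autograd.grad(f_anp), method="BFGS")

# JAX: minimize expects plain floats / NumPy arrays, so cast the outputs.
grad_f_jax = jax.jit(jax.grad(f_jax))
res_ncg = minimize(lambda x: float(f_jax(x)), x0,
                   jac=lambda x: np.asarray(grad_f_jax(x)),
                   method="Newton-CG")  # "CG" and "BFGS" are invoked the same way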
Conclusion:
JAX is rich in features, but in these tests it was consistently slower than Autograd.
Recommendations:
- Conduct a thorough investigation into the causes of the observed performance bottlenecks.
Acknowledgments:
Thank you to the developers of JAX for their ongoing efforts and contributions to the open-source community.
System info (python version, jaxlib version, accelerator, etc.)
- Hardware: Intel Core i7-9750H CPU, 16GB DDR4 RAM, NVIDIA GTX 1650 GPU, NVIDIA Tesla L4 GPU (used in Colab Pro)
- Software: JAX version 0.4.26, Jaxlib version 0.4.26, Python version 3.9.15, Jupyter notebook 5.7.2.
- Comparison Reference: Autograd version 1.6.2
- Additional info:
JAX Available devices: [cuda(id=0)]
Torch CUDA Available: True
Torch CUDA Device Name: NVIDIA GeForce GTX 1650 (on Colab: NVIDIA L4)
Torch Current CUDA Device ID: 0
Torch Number of GPUs: 1
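(For what it's worth, this kind of environment summary can also be generated directly from Python; jax.print_environment_info() is part of the public API in recent releases:)

import jax

jax.print_environment_info()   # prints jax/jaxlib versions, Python version, and device information
print(jax.devices())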
What happens if you wrap the jax function in jax.jit?
Also, can you include details on how you ran the benchmarks? Keep in mind these tips to make sure you're measuring what you think you're measuring when running benchmarks of JAX code: https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code
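For reference, the pattern that FAQ entry recommends is roughly: call the jitted function once to trigger compilation, then time steady-state calls with block_until_ready() so asynchronous dispatch is included. A sketch, with f being the jax.numpy function above:

import jax
import jax.numpy as jnp

f_jit = jax.jit(f)                      # f: the jax.numpy version of the objective
x = jnp.array([-10.0, -80.0])

_ = f_jit(x).block_until_ready()        # first call compiles; keep it out of the timing
%timeit f_jit(x).block_until_ready()    # steady-state timing, including device work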
Hi Jake,
I read the documentation you mentioned, and I believe I haven't missed anything important, since my code is simple and trivial.
- I wrapped the function in one of two ways:
  - with the decorator: @jit
  - directly on the function: jax.jit(f)
- I moved x0 to the GPU as follows:
  x = jnp.array([-10., -80])
  x0 = device_put(x, jax.devices('gpu')[0])
- I ran identical code samples using jnp and anp. The anp version completed in under a second, while the jnp version has been running for over 10 minutes (I reduced the number of iterations to a minimum to finish the test and take the screenshots, unlike anp, which broke out of the loop upon meeting the convergence criteria).
- Here is ADAM with time:
import time
import jax.numpy as jnp

def adam(grad_func, x0, alpha=0.01, beta1=0.95, beta2=0.99, epsilon=1e-3):
    start_time = time.time()
    max_iter = 500
    initial_function_value = f(x0)
    m = jnp.zeros_like(x0)
    v = jnp.zeros_like(x0)
    t = 0
    x = x0
    path = [x0]
    #while True:
    for i in range(max_iter):
        grad = grad_func(x)
        t += 1
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        x = x - alpha * m_hat / (jnp.sqrt(v_hat) + epsilon)
        path.append(x)
        # Stop when the gradient norm or the change in f falls below epsilon
        if jnp.linalg.norm(grad) < epsilon or abs(f(x) - initial_function_value) <= epsilon:
            break
        initial_function_value = f(x)
    end_time = time.time()  # End timing
    execution_time = end_time - start_time  # Calculate total execution time
    return x, path, execution_time
I am using the time library for rough performance measurement. The function in question is simple, as described.
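For comparison, a common way to make this kind of loop JIT-friendly is to jit the whole per-iteration update rather than only f, so each step launches a single compiled computation. A minimal sketch under that assumption (not the code that was benchmarked above):

import jax
import jax.numpy as jnp

@jax.jit
def adam_step(x, m, v, t, grad, alpha=0.01, beta1=0.95, beta2=0.99, epsilon=1e-3):
    # One fused ADAM update per compiled call.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    x = x - alpha * m_hat / (jnp.sqrt(v_hat) + epsilon)
    return x, m, v

grad_func = jax.jit(jax.grad(f))   # f: the jax.numpy objective above
x = jnp.array([-10., -80.])
m = jnp.zeros_like(x)
v = jnp.zeros_like(x)
for t in range(1, 501):
    g = grad_func(x)
    x, m, v = adam_step(x, m, v, t, g)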
When I try benchmarking your original function using jax.jit, I find that JAX is 4x faster than autograd on both CPU and GPU for inputs of size 1000
import autograd.numpy as anp
import jax
import jax.numpy as jnp
def f_autograd(x):
    term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley
    # Nested valley 1 (deep, narrow)
    term2_x = -4 * anp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2_y = -8 * anp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2 = term2_x + term2_y
    # Nested valley 2 (wide, shallow)
    term3_x = 2 * anp.sin(5 * anp.pi * (x[0] - 1.25)) * anp.sin(5 * anp.pi * (x[1] - 1.75))
    term3_y = 3 * anp.sin(7 * anp.pi * (x[0] - 1.25)) * anp.sin(7 * anp.pi * (x[1] - 1.75))
    term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley
    term4 = 3 * anp.sin(3 * anp.pi * x[0]) * anp.sin(3 * anp.pi * x[1])  # Oscillating term
    term5 = -5 * anp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum
    term6 = -anp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum
    term7 = -2 * anp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum
    return term1 + term2 + term3 + term4 + term5 + term6 + term7
@jax.jit
def f_jax(x):
    term1 = 0.5 * (x[0]**2 + (x[1] - 0.5)**2)  # Central parabolic valley
    # Nested valley 1 (deep, narrow)
    term2_x = -4 * jnp.exp(-(x[0] + 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2_y = -8 * jnp.exp(-(x[0] - 0.75)**2 - 10 * (x[1] - 0.3)**2)
    term2 = term2_x + term2_y
    # Nested valley 2 (wide, shallow)
    term3_x = 2 * jnp.sin(5 * jnp.pi * (x[0] - 1.25)) * jnp.sin(5 * jnp.pi * (x[1] - 1.75))
    term3_y = 3 * jnp.sin(7 * jnp.pi * (x[0] - 1.25)) * jnp.sin(7 * jnp.pi * (x[1] - 1.75))
    term3 = 0.2 * (term3_x + term3_y)  # Adjust coefficient for shallower valley
    term4 = 3 * jnp.sin(3 * jnp.pi * x[0]) * jnp.sin(3 * jnp.pi * x[1])  # Oscillating term
    term5 = -5 * jnp.exp(-(x[0] + 1)**2 - (x[1] + 1)**2)  # Deeper global minimum
    term6 = -jnp.exp(-(x[0] - 1.5)**2 - (x[1] - 1.5)**2)  # Local minimum
    term7 = -2 * jnp.exp(-(x[0] + 2)**2 - (x[1] - 2)**2)  # Local minimum
    return term1 + term2 + term3 + term4 + term5 + term6 + term7
shape = (2, 1000)
x_autograd = anp.ones(shape)
%timeit f_autograd(x_autograd)
# 797 µs ± 440 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
x_jax = jnp.ones(shape)
_ = f_jax(x_jax) # trigger compilation
%timeit f_jax(x_jax).block_until_ready()
# 141 µs ± 17.1 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
This is on a Colab CPU runtime, using the built-in %timeit magic function. On a Colab T4 GPU, the timings I get are:
374 µs ± 73.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
92.9 µs ± 3.18 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
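For completeness, the gradients (which are what the ADAM loop repeatedly evaluates) could be compared the same way; a sketch, using a length-2 input because grad requires a scalar-valued function:

from autograd import grad as autograd_grad

grad_autograd = autograd_grad(f_autograd)
grad_jax = jax.jit(jax.grad(f_jax))

xg_autograd = anp.array([-10.0, -80.0])
xg_jax = jnp.array([-10.0, -80.0])

%timeit grad_autograd(xg_autograd)

_ = grad_jax(xg_jax).block_until_ready()   # trigger compilation
%timeit grad_jax(xg_jax).block_until_ready()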
If you could include your full end-to-end benchmark script, including all imports, array definitions, function definitions, and function calls, I may be able to comment on why you are seeing different results.
Ah, I think I see now. When I ran your snippet, I got this warning:
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuSPARSE installation found.
Version JAX was built against: 12200
Minimum supported: 12100
Installed version: 12002
The local installation version must be no lower than 12100. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
221 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
118 µs ± 3.85 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
And got:
JAX Available devices: [CpuDevice(id=0)]
When running:
devices = jax.devices()
print("JAX Available devices:", devices)
But when I import jaxlib, this warning disappears and the performance drops to roughly match autograd. (By the way, I had never seen this warning before; maybe that's because I was importing jaxlib, but why exactly?)
227 µs ± 19.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) -- autograd
182 µs ± 49.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each) -- Jax
and still:
JAX Available devices: [CpuDevice(id=0)]
This is my nvcc --version:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Feb__7_19:32:13_PST_2023
Cuda compilation tools, release 12.1, V12.1.66
Build cuda_12.1.r12.1/compiler.32415258_0
And my jax and jaxlib versions are 0.4.26.
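A quick way to confirm from Python which backend actually initialized and which versions are in use (a small sketch; jax.default_backend() and the __version__ attributes are standard):

import jax
import jaxlib

print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)
print("default backend:", jax.default_backend())   # expected to report the GPU backend, not "cpu", once CUDA initializes cleanly
print("devices:", jax.devices())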