Low computational performance compared to autograd's elementwise_grad
Description
I use autograd to calculate partial derivatives of functions of two variables (x, y). Due to the end of support for autograd, I'm trying to get the same results using jax.
A typical expression is the biharmonic operator:
$$\nabla^4 w = \cfrac{\partial^4 w}{\partial x^4} + 2\cfrac{\partial^4 w}{\partial x^2\partial y^2} + \cfrac{\partial^4 w}{\partial y^4}$$
where $w = w(x,y)$.
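As a worked check, take the test function used in both examples below, $f(x,y) = x^4 + 2x^2y^2 + y^4$. Its second derivative in $x$ is $\cfrac{\partial^2 f}{\partial x^2} = 12x^2 + 4y^2$, so every fourth-order partial is constant. Any correct implementation should therefore return 64 for every element:

$$\cfrac{\partial^4 f}{\partial x^4} = 24,\qquad \cfrac{\partial^4 f}{\partial x^2\partial y^2} = \cfrac{\partial^2}{\partial y^2}\left(12x^2 + 4y^2\right) = 8,\qquad \cfrac{\partial^4 f}{\partial y^4} = 24$$

$$\nabla^4 f = 24 + 2\cdot 8 + 24 = 64$$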
Elsewhere in the program I use similar functions built by automatic differentiation as wrappers, and I obtain the final results by evaluating them on NumPy arrays.
I haven't found a way to port this kind of two-variable function from autograd to JAX with comparable performance.
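For reference, autograd's `elementwise_grad` computes a vector-Jacobian product with a vector of ones. For a function that acts elementwise on its array arguments, this equals the gradient of the sum of its outputs. The sketch below shows one hedged way to express the same idea in JAX. `egrad` is a hypothetical helper name, not a JAX API. The trick is only valid when `output[i]` depends only on `x[i]` and `y[i]`, as it does for the `f` used in the examples.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def egrad(fun, argnum=0):
    """Hypothetical JAX stand-in for autograd's elementwise_grad.

    Differentiates the *sum* of fun's outputs with respect to argument
    `argnum`, which yields per-element derivatives for elementwise functions.
    """
    return jax.grad(lambda *args: jnp.sum(fun(*args)), argnums=argnum)


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


# Nested calls stack like the autograd version; jax.jit compiles each
# fourth-order derivative into a single XLA computation.
d_xxxx = jax.jit(egrad(egrad(egrad(egrad(f, 0), 0), 0), 0))
d_xxyy = jax.jit(egrad(egrad(egrad(egrad(f, 0), 0), 1), 1))

x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)
res_xxxx = d_xxxx(x, y)  # expected to be 24 everywhere for this f
res_xxyy = d_xxyy(x, y)  # expected to be 8 everywhere for this f
```

With `argnum` set to 1 throughout, the same pattern gives the fourth derivative in $y$. The biharmonic result is then the weighted sum of the three derivatives.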
Examples:
autograd (ex1.py)
```python
import numpy as np
from autograd import elementwise_grad as egrad

dx, dy = 0, 1


def nabla4(w):
    def fn(x, y):
        return (
            egrad(egrad(egrad(egrad(w, dx), dx), dx), dx)(x, y)
            + 2 * egrad(egrad(egrad(egrad(w, dx), dx), dy), dy)(x, y)
            + egrad(egrad(egrad(egrad(w, dy), dy), dy), dy)(x, y)
        )
    return fn


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = np.arange(10_000, dtype=np.float64)
y = np.arange(10_000, dtype=np.float64)
w = [f] * 100  # In a real program, the elements of the list are various functions.
r = [nabla4(f)(x, y) for f in w]
```
```
(idp) PS C:\Users\kryst\Projects\example> Measure-Command { python ex1.py }

Days              : 0
Hours             : 0
Minutes           : 0
Seconds           : 0
Milliseconds      : 813
Ticks             : 8130392
TotalDays         : 9,41017592592593E-06
TotalHours        : 0,000225844222222222
TotalMinutes      : 0,0135506533333333
TotalSeconds      : 0,8130392
TotalMilliseconds : 813,0392
```
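`Measure-Command` times the whole `python exN.py` process. Its figure therefore includes interpreter startup and library imports as well as the derivative evaluation. For JAX it also includes tracing and, where used, compilation. The sketch below times only the computation. `time_call` is a hypothetical helper. It calls `jax.block_until_ready` because JAX dispatches work asynchronously, and it assumes JAX is importable in the environment being measured.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, repeat=5):
    """Hypothetical helper: best wall-clock time (seconds) of fn(*args).

    One warm-up call runs first, so that tracing and JIT compilation are
    excluded. jax.block_until_ready waits for asynchronously dispatched work
    before the clock is read (non-JAX outputs such as NumPy arrays pass through).
    """
    jax.block_until_ready(fn(*args))  # warm-up
    best = float("inf")
    for _ in range(repeat):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best


elapsed = time_call(lambda a: jnp.sin(a).sum(), jnp.arange(1_000.0))
```

In either script, `time_call(nabla4(f), x, y)` would then measure a single evaluation of the operator without the process-level overhead.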
jax (ex2.py)
```python
import jax
from jax import grad, vmap
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dx, dy = 0, 1


def nabla4(w):
    def fn(x, y):
        return (
            vmap(grad(grad(grad(grad(w, dx), dx), dx), dx))(x, y)
            + 2 * vmap(grad(grad(grad(grad(w, dx), dx), dy), dy))(x, y)
            + vmap(grad(grad(grad(grad(w, dy), dy), dy), dy))(x, y)
        )
    return fn


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)
w = [f] * 100  # In a real program, the elements of the list are various functions.
r = [nabla4(f)(x, y) for f in w]
```
```
(idp) PS C:\Users\kryst\Projects\example> Measure-Command { python ex2.py }

Days              : 0
Hours             : 0
Minutes           : 0
Seconds           : 6
Milliseconds      : 906
Ticks             : 69064939
TotalDays         : 7,99362719907407E-05
TotalHours        : 0,00191847052777778
TotalMinutes      : 0,115108231666667
TotalSeconds      : 6,9064939
TotalMilliseconds : 6906,4939
```
The JAX version takes about 6.9 s versus 0.8 s for autograd, roughly 8.5x slower. In more complicated programs the difference is much larger.
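In `ex2.py`, every call to `nabla4(f)` builds new transformed functions. Because nothing is wrapped in `jax.jit`, the vmapped fourth-order derivatives execute eagerly, roughly one primitive at a time, with Python dispatch overhead on each step. A common first step is to build the derivatives once per function and compile the batched operator with `jax.jit`. The sketch below does this and caches the compiled operator per function object with `functools.lru_cache`. `nabla4_jit` is a hypothetical name. It is an untested assumption that this closes the gap for the real program, where each distinct function in the list would still be traced and compiled once.

```python
import functools

import jax
import jax.numpy as jnp
from jax import grad, vmap

jax.config.update("jax_enable_x64", True)

dx, dy = 0, 1


@functools.lru_cache(maxsize=None)
def nabla4_jit(w):
    """Hypothetical compiled variant of nabla4, built once per function w."""
    # Build the scalar fourth-order partials once instead of on every call.
    d_xxxx = grad(grad(grad(grad(w, dx), dx), dx), dx)
    d_xxyy = grad(grad(grad(grad(w, dx), dx), dy), dy)
    d_yyyy = grad(grad(grad(grad(w, dy), dy), dy), dy)

    def point(x, y):
        # Biharmonic operator at a single (x, y) point.
        return d_xxxx(x, y) + 2 * d_xxyy(x, y) + d_yyyy(x, y)

    # Batch over the input arrays, then compile the whole computation with XLA.
    return jax.jit(vmap(point))


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)
w = [f] * 100  # same list as in ex2.py
# Repeated entries hit the lru_cache, so f is traced and compiled only once.
r = [nabla4_jit(fn)(x, y) for fn in w]
```

The results are JAX arrays. Call `jax.block_until_ready` on them before timing, or convert them with `np.asarray` before passing them to NumPy code.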
System info (python version, jaxlib version, accelerator, etc.)
```
jaxlib: 0.4.30
numpy: 1.26.4
python: 3.10.13 | packaged by conda-forge | (tags/v3.10.13-25-g07fbd8e9251-dirty:07fbd8e9251, Dec 28 2023, 15:38:17) [MSC v.1929 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Vero', release='10', version='10.0.22631', machine='AMD64')
```