
Low calculation performance compared to autograd elementwise_grad

Open krysros opened this issue 9 months ago • 1 comment

Description

I use autograd to calculate partial derivatives of functions of two variables (x, y). Since autograd is no longer maintained, I'm trying to get the same results using jax.

These functions have the form:

$$\nabla^4 w = \cfrac{\partial^4 w}{\partial x^4} + 2\cfrac{\partial^4 w}{\partial x^2\partial y^2} + \cfrac{\partial^4 w}{\partial y^4}$$

where $w = w(x,y)$.
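For the test function used in the examples below, $w = x^4 + 2x^2y^2 + y^4$, every term of the operator is constant, which gives a quick sanity check for both implementations:

$$\nabla^4 w = 24 + 2 \cdot 8 + 24 = 64$$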

I use similar functions obtained by automatic differentiation as wrappers in other parts of the program, and then substitute NumPy array values into them to obtain the final results.

I haven't found a way to port this type of two-variable function from autograd to jax with similar performance.

Examples:

autograd (ex1.py)

import numpy as np
from autograd import elementwise_grad as egrad


dx, dy = 0, 1  # argnum indices: differentiate w.r.t. x (0) or y (1)


def nabla4(w):
    def fn(x, y):
        return (
            egrad(egrad(egrad(egrad(w, dx), dx), dx), dx)(x, y)
            + 2 * egrad(egrad(egrad(egrad(w, dx), dx), dy), dy)(x, y)
            + egrad(egrad(egrad(egrad(w, dy), dy), dy), dy)(x, y)
        )

    return fn


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = np.arange(10_000, dtype=np.float64)
y = np.arange(10_000, dtype=np.float64)


w = [f] * 100  # In a real program, the elements of the list are various functions.

r = [nabla4(f)(x, y) for f in w]

(idp) PS C:\Users\kryst\Projects\example> Measure-Command { python ex1.py }

Days              : 0
Hours             : 0
Minutes           : 0
Seconds           : 0
Milliseconds      : 813
Ticks             : 8130392
TotalDays         : 9,41017592592593E-06
TotalHours        : 0,000225844222222222
TotalMinutes      : 0,0135506533333333
TotalSeconds      : 0,8130392
TotalMilliseconds : 813,0392

jax (ex2.py)

import jax
from jax import grad, vmap
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


dx, dy = 0, 1  # argnum indices: differentiate w.r.t. x (0) or y (1)


def nabla4(w):
    def fn(x, y):
        return (
            vmap(grad(grad(grad(grad(w, dx), dx), dx), dx))(x, y)
            + 2 * vmap(grad(grad(grad(grad(w, dx), dx), dy), dy))(x, y)
            + vmap(grad(grad(grad(grad(w, dy), dy), dy), dy))(x, y)
        )

    return fn


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)


w = [f] * 100  # In a real program, the elements of the list are various functions.

r = [nabla4(f)(x, y) for f in w]

(idp) PS C:\Users\kryst\Projects\example> Measure-Command { python ex2.py }

Days              : 0
Hours             : 0
Minutes           : 0
Seconds           : 6
Milliseconds      : 906
Ticks             : 69064939
TotalDays         : 7,99362719907407E-05
TotalHours        : 0,00191847052777778
TotalMinutes      : 0,115108231666667
TotalSeconds      : 6,9064939
TotalMilliseconds : 6906,4939

The program using jax is almost 9x slower than the version using autograd. In more complicated programs the differences are much greater.
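
One possible mitigation, sketched below as a hypothetical ex3.py (whether it recovers autograd-level performance for this workload is untested), is to wrap the composed operator in jax.jit, so the four nested grad transforms are traced and compiled into a single XLA computation instead of being re-traced and dispatched op by op on every evaluation:

import jax
from jax import grad, jit, vmap
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

dx, dy = 0, 1  # argnum indices: differentiate w.r.t. x (0) or y (1)


def nabla4(w):
    def fn(x, y):
        return (
            vmap(grad(grad(grad(grad(w, dx), dx), dx), dx))(x, y)
            + 2 * vmap(grad(grad(grad(grad(w, dx), dx), dy), dy))(x, y)
            + vmap(grad(grad(grad(grad(w, dy), dy), dy), dy))(x, y)
        )

    # jit compiles the whole operator once per distinct w; the first call
    # pays tracing/compilation, subsequent calls reuse the compiled code.
    return jit(fn)


def f(x, y):
    return x**4 + 2 * x**2 * y**2 + y**4


x = jnp.arange(10_000, dtype=jnp.float64)
y = jnp.arange(10_000, dtype=jnp.float64)

# With distinct functions, build each jitted operator once and reuse it.
op = nabla4(f)
r = [op(x, y) for _ in range(100)]

Note that jit only pays off when each operator is reused: in the benchmark above every wrapper is evaluated exactly once, so the first-call compilation cost is still included in the total. Also, JAX dispatches work asynchronously, so in-process timings of individual calls should call jax.block_until_ready on the result.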

System info (python version, jaxlib version, accelerator, etc.)

jaxlib: 0.4.30
numpy:  1.26.4
python: 3.10.13 | packaged by conda-forge | (tags/v3.10.13-25-g07fbd8e9251-dirty:07fbd8e9251, Dec 28 2023, 15:38:17) [MSC v.1929 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', node='Vero', release='10', version='10.0.22631', machine='AMD64')

krysros · May 26 '24 10:05