
using jax like numba on GPUs and CPUs

beckermr opened this issue 5 years ago • 9 comments

Per our discussion, here is an example of a kernel we tried to write in JAX but could not figure out how to make efficient. It has a couple of nested loops.

This is the CUDA version. The CPU one looks much the same, minus the start/stride logic, which spreads the outer loop across the GPU threads.

# for GPU timing using Numba
from numba import cuda


@cuda.jit
def count_weighted_pairs_3d_cuda(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result):
    # grid-stride loop: start at this thread's global index and step by
    # the total number of threads in the grid
    start = cuda.grid(1)
    stride = cuda.gridsize(1)

    n1 = x1.shape[0]
    n2 = x2.shape[0]
    nbins = rbins_squared.shape[0]

    for i in range(start, n1, stride):
        px = x1[i]
        py = y1[i]
        pz = z1[i]
        pw = w1[i]
        for j in range(n2):
            qx = x2[j]
            qy = y2[j]
            qz = z2[j]
            qw = w2[j]
            dx = px-qx
            dy = py-qy
            dz = pz-qz
            wprod = pw*qw
            dsq = dx*dx + dy*dy + dz*dz

            # add wprod to result[k-1] for every edge k >= 1 with
            # dsq <= rbins_squared[k]; atomics because threads share bins
            k = nbins - 1
            while dsq <= rbins_squared[k]:
                cuda.atomic.add(result, k-1, wprod)
                k -= 1
                if k <= 0:
                    break
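The start/stride lines implement what is usually called a grid-stride loop: each thread begins at its global index (`cuda.grid(1)`) and jumps ahead by the total number of threads in the grid (`cuda.gridsize(1)`), so the n1 outer iterations are split across threads while each thread still runs the full inner loop over n2. Because threads working on different rows can hit the same bin at the same time, the increments go through `cuda.atomic.add`. The pure-Python sketch below needs no GPU; the launch sizes and the helper names `grid_stride_indices` and `simulate_launch` are hypothetical, chosen only to show which rows each thread would visit.

```python
def grid_stride_indices(thread_id, total_threads, n):
    """Rows visited by one thread: mirrors `range(start, n1, stride)`."""
    return list(range(thread_id, n, total_threads))


def simulate_launch(blocks, threads_per_block, n):
    """Per-thread row lists for a hypothetical 1D launch."""
    # blocks * threads_per_block is what cuda.gridsize(1) reports in a 1D launch
    total_threads = blocks * threads_per_block
    return [grid_stride_indices(t, total_threads, n) for t in range(total_threads)]


# Toy launch: 2 blocks x 4 threads = 8 threads sharing 20 rows of x1.
assignment = simulate_launch(blocks=2, threads_per_block=4, n=20)
for t, rows in enumerate(assignment):
    print(f"thread {t}: rows {rows}")
```

Every row lands in exactly one thread's list, so no pair is counted twice, and the grid size can be chosen independently of n1.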

cc @mattjj @shoyer @aphearin

beckermr, May 07 '20 20:05

For cross-reference: when we first wrote this in Numba, we hit a big performance bottleneck that appears to have since been fixed.

xref: https://github.com/numba/numba/issues/4647

beckermr, May 07 '20 20:05

Also, to be clear, we don't want to use autodiff in this kernel, which I realize is outside the "A" (Autograd) part of JAX. ;p

beckermr, May 07 '20 20:05

Just to make our lives easy: can you expand the example to something self-contained and runnable (e.g., generate some random numbers for inputs) and show what timing you got with the Numba version?

There are two paths we can go down in general here: try to make a fast JAX version, or try to integrate with Numba. Both are fine ways to go! JAX and Numba have different strengths, and you should be able to use them together, but it would be remiss of us not to look at JAX performance!

hawkinsp, May 07 '20 20:05

Sure, here you go!

import time

import numpy as np
from numba import njit


@njit
def count_weighted_pairs_3d_cpu(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result):
    n1 = x1.shape[0]
    n2 = x2.shape[0]
    nbins = rbins_squared.shape[0]

    for i in range(n1):
        px = x1[i]
        py = y1[i]
        pz = z1[i]
        pw = w1[i]
        for j in range(n2):
            qx = x2[j]
            qy = y2[j]
            qz = z2[j]
            qw = w2[j]
            dx = px-qx
            dy = py-qy
            dz = pz-qz
            wprod = pw*qw
            dsq = dx*dx + dy*dy + dz*dz

            # walk down from the largest bin edge: result[k-1] collects the
            # weight of every pair with dsq <= rbins_squared[k]
            k = nbins - 1
            while dsq <= rbins_squared[k]:
                result[k-1] += wprod
                k -= 1
                if k <= 0:
                    break


# parameters
npoints = 30000
n1 = npoints
n2 = npoints
n_try = 5

# make the data
rbins_squared = (np.logspace(
    np.log10(0.1/1e3), np.log10(40/1e3), 20)**2).astype(np.float32)
rng = np.random.RandomState(seed=42)
x1, y1, z1, w1 = rng.uniform(size=(4, n1)).astype(np.float32)
x2, y2, z2, w2 = rng.uniform(size=(4, n2)).astype(np.float32)

# array init: one entry per bin edge except the first
result = np.zeros_like(rbins_squared)[:-1].astype(np.float32)

# do once to compile the code
# (the kernel updates `result` in place and returns None; result keeps
# accumulating across calls, which does not affect the timing)
count_weighted_pairs_3d_cpu(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result)

t0 = time.time()
for _ in range(n_try):
    count_weighted_pairs_3d_cpu(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result)
t0 = (time.time() - t0) / n_try

print("numba: %s seconds" % t0)
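The `while` loop is easy to misread: it does not fill one bin per pair. It adds the pair's weight to `result[k-1]` for every edge `k >= 1` that lies at or beyond `dsq`, so `result[k-1]` ends up as the cumulative weighted pair count within radius `sqrt(rbins_squared[k])`, and edge 0 never gets its own entry. A minimal NumPy sketch of that reading, assuming sorted bin edges: a plain-Python port of the loop nest is compared against a swapped-in formulation (bin lookup with `np.searchsorted`, a weighted `np.bincount`, then `np.cumsum`). The helper names and input sizes are hypothetical, and the vectorized version still builds n1 × n2 arrays, so it is only meant for checking small inputs.

```python
import numpy as np


def pair_counts_loop(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared):
    """Plain-Python port of the kernel's loop nest (slow: tiny inputs only)."""
    nbins = rbins_squared.shape[0]
    result = np.zeros(nbins - 1)
    for i in range(x1.shape[0]):
        for j in range(x2.shape[0]):
            dx = x1[i] - x2[j]
            dy = y1[i] - y2[j]
            dz = z1[i] - z2[j]
            dsq = dx*dx + dy*dy + dz*dz
            k = nbins - 1
            while dsq <= rbins_squared[k]:
                result[k - 1] += w1[i] * w2[j]
                k -= 1
                if k <= 0:
                    break
    return result


def pair_counts_cumhist(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared):
    """Same quantity as a weighted histogram followed by a cumulative sum."""
    nbins = rbins_squared.shape[0]
    dx = x1[:, None] - x2
    dy = y1[:, None] - y2
    dz = z1[:, None] - z2
    dsq = (dx*dx + dy*dy + dz*dz).ravel()
    wprod = (w1[:, None] * w2).ravel()
    # idx = first edge k with rbins_squared[k] >= dsq (nbins if there is none)
    idx = np.searchsorted(rbins_squared, dsq, side="left")
    hist = np.bincount(idx, weights=wprod, minlength=nbins + 1)
    # cum[k] = total weight of pairs with dsq <= rbins_squared[k]
    cum = np.cumsum(hist)[:nbins]
    # the kernel stores edge k in result[k - 1] and never fills edge 0
    return cum[1:]


# Small random inputs (sizes and bin edges are arbitrary).
rng = np.random.RandomState(0)
x1, y1, z1, w1 = rng.uniform(size=(4, 40)).astype(np.float32)
x2, y2, z2, w2 = rng.uniform(size=(4, 30)).astype(np.float32)
rbins_squared = (np.linspace(0.1, 1.0, 8) ** 2).astype(np.float32)
args = (x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared)
print(np.allclose(pair_counts_loop(*args), pair_counts_cumhist(*args)))
```

Seen this way, the kernel is a cumulative weighted histogram, which is also why its output is non-decreasing in `k` for non-negative weights.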

beckermr, May 07 '20 21:05

I get

(anl) clarence:Desktop beckermr$ python numba_cou_jit.py 
numba: 2.117280626296997 seconds

beckermr, May 07 '20 21:05

Interesting!

You can write the same thing easily enough in JAX like this:

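import jax
import jax.numpy as jnp


@jax.jit
def count_weighted_pairs_3d_cpu_jax(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared):
    # all-pairs differences via broadcasting: shape (n1, n2)
    dx = x1[:, None] - x2
    dy = y1[:, None] - y2
    dz = z1[:, None] - z2
    wprod = w1[:, None] * w2
    dsq = dx*dx + dy*dy + dz*dz
    # out[k] = total weight of pairs with dsq <= rbins_squared[k];
    # the comparison broadcasts to shape (nbins, n1, n2)
    return jnp.sum(jnp.where(dsq <= rbins_squared[:, None, None], wprod, 0),
                   axis=(1, 2))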

but on CPU it scales poorly with the number of points. It looks like the compiler is materializing the input to the sum reduction, an (nbins, n1, n2) array, which is quadratic in the number of points. I think that's something we can fix: we need to fuse the elementwise computation with the reduction, which does not seem to be implemented on CPU.

However, on GPU it seems to perform reasonably. Did you try GPU?
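At the benchmark sizes, the dense version's comparison has shape (20, 30000, 30000), i.e. 1.8e10 float32 values or roughly 72 GB if fully materialized. A minimal sketch of one possible CPU workaround, assuming the only goal is to avoid that intermediate (this is not what the thread proposes, and it has not been benchmarked here): swap the per-edge comparison for a bin lookup with `jnp.searchsorted`, a scatter-add histogram via `.at[].add`, and a cumulative sum, and walk `x1` in chunks with `jax.lax.map` so each step only builds chunk_size × n2 arrays. `make_chunked_pair_counter` and `chunk_size` are hypothetical names. Like the dense version it returns nbins entries, so `out[1:]` lines up with the Numba `result` array.

```python
import jax
import jax.numpy as jnp
import numpy as np


def make_chunked_pair_counter(chunk_size):
    """Return a jitted counter that walks x1 in blocks of `chunk_size` rows.

    Peak intermediate size is about chunk_size * n2 elements instead of
    nbins * n1 * n2. `chunk_size` is a hypothetical tuning knob.
    """

    @jax.jit
    def count(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared):
        nbins = rbins_squared.shape[0]
        n1 = x1.shape[0]
        n_chunks = -(-n1 // chunk_size)  # ceiling division (static shapes)
        pad = n_chunks * chunk_size - n1

        def split(a):
            # zero-pad to whole chunks; padded rows get weight 0 via w1
            return jnp.pad(a, (0, pad)).reshape(n_chunks, chunk_size)

        def one_chunk(chunk):
            cx, cy, cz, cw = chunk
            dx = cx[:, None] - x2
            dy = cy[:, None] - y2
            dz = cz[:, None] - z2
            dsq = (dx * dx + dy * dy + dz * dz).ravel()
            wprod = (cw[:, None] * w2).ravel()
            # idx = first edge k with rbins_squared[k] >= dsq (nbins if none)
            idx = jnp.searchsorted(rbins_squared, dsq, side="left")
            # weighted histogram over idx; .at[].add sums repeated indices
            return jnp.zeros(nbins + 1, dtype=wprod.dtype).at[idx].add(wprod)

        # lax.map applies one_chunk sequentially along the leading axis
        hists = jax.lax.map(one_chunk, (split(x1), split(y1), split(z1), split(w1)))
        # cum[k] = total weight of pairs with dsq <= rbins_squared[k]
        return jnp.cumsum(hists.sum(axis=0))[:nbins]

    return count


# Usage on small random inputs (sizes and edges are arbitrary).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x1, y1, z1, w1 = jax.random.uniform(k1, (4, 300))
x2, y2, z2, w2 = jax.random.uniform(k2, (4, 200))
rbins_squared = jnp.linspace(0.05, 0.5, 10) ** 2
out = make_chunked_pair_counter(chunk_size=64)(
    x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared)
print(np.asarray(out))
```

Because `lax.map` processes one chunk at a time, `chunk_size` trades peak memory against per-step parallelism; whether this beats the Numba loop on CPU would need measuring.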

hawkinsp, May 08 '20 00:05

I have not benchmarked the GPU version against JAX. Note, however, that this is the simplest of a class of kernels. I'm not sure the version you posted will generalize to the others.

beckermr, May 08 '20 01:05

Also interesting! It'd be nice to have a really crisp example if you have one.

(And as I mentioned before, making it easier to integrate with Numba is a fine strategy too.)

hawkinsp, May 08 '20 01:05

xref: #1870

froystig, May 14 '20 23:05