Using JAX like Numba on GPUs and CPUs
Per our discussion, here is an example of a kernel we tried to write in JAX but could not figure out how to make efficient. It has a couple of nested loops.
This is the CUDA version. The CPU version looks much the same, apart from the `start`/`stride` bookkeeping that spreads the outer loop efficiently across GPU threads.
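```python
from numba import cuda


# for GPU timing using Numba
@cuda.jit
def count_weighted_pairs_3d_cuda(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result):
    # grid-stride loop: each thread handles i = start, start + stride, ...
    start = cuda.grid(1)
    stride = cuda.gridsize(1)

    n1 = x1.shape[0]
    n2 = x2.shape[0]
    nbins = rbins_squared.shape[0]

    for i in range(start, n1, stride):
        px = x1[i]
        py = y1[i]
        pz = z1[i]
        pw = w1[i]
        for j in range(n2):
            qx = x2[j]
            qy = y2[j]
            qz = z2[j]
            qw = w2[j]
            dx = px - qx
            dy = py - qy
            dz = pz - qz
            wprod = pw * qw
            dsq = dx * dx + dy * dy + dz * dz

            # walk down from the largest bin; result[k-1] accumulates the
            # weight of all pairs with dsq <= rbins_squared[k]
            k = nbins - 1
            while dsq <= rbins_squared[k]:
                # many threads update the same bins, so the add must be atomic
                cuda.atomic.add(result, k - 1, wprod)
                k -= 1
                if k <= 0:
                    break
```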
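To make the start/stride bookkeeping concrete, here is a pure-Python sketch (no GPU needed) of how the grid-stride loop `range(start, n1, stride)` splits the outer loop across threads. It assumes a 1-D launch, where `cuda.grid(1)` is the global thread index and `cuda.gridsize(1)` is the total number of launched threads; the helper name `grid_stride_indices` is hypothetical.

```python
def grid_stride_indices(n1, n_blocks, threads_per_block):
    """Map each global thread id to the outer-loop indices it would handle.

    Hypothetical helper emulating a 1-D launch: `start` plays the role of
    cuda.grid(1) and `stride` the role of cuda.gridsize(1).
    """
    stride = n_blocks * threads_per_block
    return {start: list(range(start, n1, stride)) for start in range(stride)}


# 10 points, 2 blocks of 2 threads -> stride 4
for tid, idx in grid_stride_indices(10, n_blocks=2, threads_per_block=2).items():
    print(tid, idx)
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

Every `i` is visited by exactly one thread, so the only writes shared between threads are the ones into `result`, which is why the kernel needs `cuda.atomic.add`.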
cc @mattjj @shoyer @aphearin
For cross-reference: when we first wrote this in Numba, we hit a big performance bottleneck that appears to have been fixed since.
xref: https://github.com/numba/numba/issues/4647
Also, to be clear, we don't need autodiff for this kernel, which I realize puts it outside the "A" (autodiff) part of JAX. ;p
Just to make our lives easy: can you expand the example to something self-contained and runnable (e.g., generate some random numbers for inputs) and show what timing you got with the Numba version?
There are two general paths we can go down here: try to make a fast JAX version, or try to integrate with Numba. Both are fine ways to go! JAX and Numba have different strengths, and you should be able to use them together, but it would be remiss of us not to look at the JAX performance!
Sure, here you go!
```python
import time

import numpy as np
from numba import njit


@njit
def count_weighted_pairs_3d_cpu(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result):
    n1 = x1.shape[0]
    n2 = x2.shape[0]
    nbins = rbins_squared.shape[0]

    for i in range(n1):
        px = x1[i]
        py = y1[i]
        pz = z1[i]
        pw = w1[i]
        for j in range(n2):
            qx = x2[j]
            qy = y2[j]
            qz = z2[j]
            qw = w2[j]
            dx = px - qx
            dy = py - qy
            dz = pz - qz
            wprod = pw * qw
            dsq = dx * dx + dy * dy + dz * dz

            # result[k-1] accumulates the weight of all pairs with
            # dsq <= rbins_squared[k] (cumulative, since the bins increase)
            k = nbins - 1
            while dsq <= rbins_squared[k]:
                result[k - 1] += wprod
                k -= 1
                if k <= 0:
                    break


# parameters
npoints = 30000
n1 = npoints
n2 = npoints
n_try = 5

# make the data: 20 log-spaced bin edges, squared
rbins_squared = (np.logspace(
    np.log10(0.1 / 1e3), np.log10(40 / 1e3), 20) ** 2).astype(np.float32)
rng = np.random.RandomState(seed=42)
x1, y1, z1, w1 = rng.uniform(size=(4, n1)).astype(np.float32)
x2, y2, z2, w2 = rng.uniform(size=(4, n2)).astype(np.float32)

# array init: one output slot per bin edge after the first
result = np.zeros_like(rbins_squared)[:-1].astype(np.float32)

# do once to compile the code (the kernel updates `result` in place and
# returns None, so `result` keeps accumulating across calls)
count_weighted_pairs_3d_cpu(
    x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result)

t0 = time.time()
for _ in range(n_try):
    count_weighted_pairs_3d_cpu(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared, result)
t0 = (time.time() - t0) / n_try
print("numba: %s seconds" % t0)
```
I get:

```
(anl) clarence:Desktop beckermr$ python numba_cou_jit.py
numba: 2.117280626296997 seconds
```
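Because the while loop walks down from the largest bin and `rbins_squared` is increasing, `result[k-1]` ends up holding the total weight of all pairs with `dsq <= rbins_squared[k]`, i.e. a cumulative weighted pair count. A brute-force NumPy reference like the sketch below is a convenient check for any port of the kernel. The function name is hypothetical, and it builds the full `(n1, n2)` distance matrix, so it is only meant for small inputs.

```python
import numpy as np


def weighted_pairs_reference(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared):
    """Brute-force reference for the Numba kernel's `result` (small inputs only)."""
    dx = x1[:, None] - x2
    dy = y1[:, None] - y2
    dz = z1[:, None] - z2
    dsq = dx * dx + dy * dy + dz * dz          # (n1, n2) squared distances
    wprod = w1[:, None] * w2                   # (n1, n2) weight products
    # entry k-1 = total weight of pairs with dsq <= rbins_squared[k], k >= 1,
    # assuming rbins_squared is sorted in increasing order
    return np.array([wprod[dsq <= r2].sum() for r2 in rbins_squared[1:]],
                    dtype=np.float64)


rng = np.random.RandomState(seed=0)
x1, y1, z1, w1 = rng.uniform(size=(4, 200))
x2, y2, z2, w2 = rng.uniform(size=(4, 300))
rbins_squared = np.logspace(np.log10(0.01), np.log10(0.3), 8) ** 2
print(weighted_pairs_reference(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared))
```

With float32 inputs, expect agreement with the kernel only up to float32 accumulation error, so compare with `np.allclose` rather than exact equality.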
Interesting!
You can write the same thing easily enough in JAX like this:
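```python
import jax
import jax.numpy as jnp


@jax.jit
def count_weighted_pairs_3d_cpu_jax(
        x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared):
    # all pairwise differences via broadcasting: shape (n1, n2)
    dx = x1[:, None] - x2
    dy = y1[:, None] - y2
    dz = z1[:, None] - z2
    wprod = w1[:, None] * w2
    dsq = dx*dx + dy*dy + dz*dz
    # entry k is the total weight of pairs with dsq <= rbins_squared[k];
    # entries [1:] correspond to the Numba kernel's `result`
    return jnp.sum(jnp.where(dsq <= rbins_squared[:, None, None], wprod, 0),
                   axis=(1, 2))
```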
But on CPU it scales poorly with the number of points. It looks like the compiler is materializing the input to the sum reduction (the `(nbins, n1, n2)` array produced by `jnp.where`), which is quadratic in the number of points. I think that's something we can fix: we need to fuse the elementwise computation with the reduction, which seems not to be implemented on CPU.
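For a sense of scale, here is the back-of-the-envelope arithmetic for the benchmark above (`npoints = 30000`, 20 bins, float32), assuming the `(nbins, n1, n2)` array built by `jnp.where` is fully materialized rather than fused into the sum:

```python
npoints = 30000            # n1 == n2 in the benchmark
nbins = 20                 # length of rbins_squared
bytes_per_float32 = 4

pairs = npoints * npoints                      # 900,000,000 pair distances
dsq_gb = pairs * bytes_per_float32 / 1e9       # the (n1, n2) dsq array alone
where_gb = nbins * dsq_gb                      # the (nbins, n1, n2) where() output
print(f"dsq: {dsq_gb:.1f} GB, where(): {where_gb:.1f} GB")
# dsq: 3.6 GB, where(): 72.0 GB
```

Even when fused, this formulation does `nbins` comparisons for every pair, whereas the while loop in the Numba kernel stops after a single test for any pair farther apart than the largest bin, which is most pairs here.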
However, on GPU it seems to perform reasonably. Did you try GPU?
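One way to sidestep the materialization by hand, without waiting for better fusion on CPU, is to bound the intermediate explicitly. The sketch below is a hypothetical variant (the name `count_weighted_pairs_chunked` and the `chunk_size` parameter are illustrative, not from the thread). It loops over blocks of the first point set with `jax.lax.map`, so only a `(chunk_size, n2)` block of distances should be live at a time. It also replaces the per-bin where/sum with `jnp.searchsorted` plus a weighted `jnp.bincount` and a `cumsum`, which removes the `nbins` factor. It returns the same length-`nbins` cumulative sums as the version above.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="chunk_size")
def count_weighted_pairs_chunked(x1, y1, z1, w1, x2, y2, z2, w2,
                                 rbins_squared, chunk_size=1024):
    """Hypothetical chunked variant: entry k = weight of pairs with dsq <= rbins_squared[k]."""
    nbins = rbins_squared.shape[0]
    n1 = x1.shape[0]
    n_chunks = -(-n1 // chunk_size)              # ceiling division (static)
    pad = n_chunks * chunk_size - n1

    def blocks(a):
        # zero-pad the first point set; padded points have weight 0,
        # so they contribute nothing to any bin
        return jnp.pad(a, (0, pad)).reshape(n_chunks, chunk_size)

    def one_chunk(chunk):
        cx, cy, cz, cw = chunk
        dx = cx[:, None] - x2
        dy = cy[:, None] - y2
        dz = cz[:, None] - z2
        dsq = (dx * dx + dy * dy + dz * dz).ravel()   # chunk_size * n2 values
        wprod = (cw[:, None] * w2).ravel()
        # index of the first bin with dsq <= rbins_squared[k]; nbins means no bin
        idx = jnp.searchsorted(rbins_squared, dsq, side="left")
        return jnp.bincount(idx, weights=wprod, length=nbins + 1)

    hist = jax.lax.map(one_chunk, (blocks(x1), blocks(y1), blocks(z1), blocks(w1)))
    # a pair whose first bin is m counts toward every k >= m
    return jnp.cumsum(hist.sum(axis=0))[:nbins]
```

Calling it as `count_weighted_pairs_chunked(x1, y1, z1, w1, x2, y2, z2, w2, rbins_squared)[1:]` should line up with the Numba `result` up to float32 rounding. `chunk_size` trades peak memory against per-iteration parallelism; 1024 is an arbitrary default, not a tuned value.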
I have not benchmarked the GPU version against JAX. Note, however, that this is the simplest of a class of kernels, and I don't know whether the version you posted will generalize to the others.
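If you do benchmark JAX against the Numba numbers, one detail matters: JAX dispatches work asynchronously, so a timing loop should call `block_until_ready()` on each result and exclude the first (compiling) call, otherwise it mostly measures dispatch. A minimal sketch of such a harness, mirroring the `n_try` loop used for Numba above (`time_jax` and `pairwise_dsq_sum` are hypothetical names):

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, n_try=5):
    """Average wall-clock seconds per call of a jitted JAX function."""
    fn(*args).block_until_ready()        # first call triggers compilation; not timed
    t0 = time.perf_counter()
    for _ in range(n_try):
        fn(*args).block_until_ready()    # wait for the asynchronous computation
    return (time.perf_counter() - t0) / n_try


@jax.jit
def pairwise_dsq_sum(x1, x2):
    # toy stand-in for one of the kernels in this thread
    d = x1[:, None] - x2
    return jnp.sum(d * d)


x = jnp.linspace(0.0, 1.0, 2000)
print("jax: %s seconds" % time_jax(pairwise_dsq_sum, x, x))
```

The same helper runs on whichever backend JAX picked up; `jax.devices()` shows whether that is CPU or GPU.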
Also interesting! It'd be nice to have a crisp example of one of those harder kernels, if you have one.
(And as I mentioned before, making it easier to integrate with Numba is a fine strategy too.)
xref: #1870