Conditional array update on GPU using jnp.where vs fori_loop
Description:
Hi!
I am fairly new to JAX. I have been trying to update one or more entries of a 1D array based on some condition, inside a function wrapped with jax.jit and jnp.vectorize. I found a very fast way of doing this on CPU, but when I tested it on GPU it suffered from an extreme slowdown (around 1000x).
For more context, the approach that seems to be fastest on CPU looks like this:
import functools

import jax
import jax.numpy as jnp
from jax.lax import cond, fori_loop


@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)")
def fun_vectorized(r, m, n, dr):
    """
    m.size == n.size
    m and n may have repeating values
    N_max and M_max values are found using some functions
    """

    def update(i, args):
        """Updates the output if required."""
        alpha, N, result, out = args
        # idx is the index to update, or -1 if no update is needed
        idx = jnp.where(jnp.logical_and(m[i] == alpha, n[i] == N), i, -1)

        def falseFun(args):
            """Do nothing."""
            _, _, out = args
            return out

        def trueFun(args):
            """Update the value at idx."""
            idx, result, out = args
            out = out.at[idx].set(result)
            return out

        out = cond(idx >= 0, trueFun, falseFun, (idx, result, out))
        return (alpha, N, result, out)

    def body_inner(N, args):
        alpha, out = args
        result = find_some_array()  # irrelevant for this issue
        # Update the output array
        _, _, _, out = fori_loop(0, m.size, update, (alpha, N, result, out))
        return (alpha, out)

    def body(alpha, out):
        # Find the max value of n corresponding to alpha
        # (requires another function which is not relevant here)
        N_max = find_Nmax()
        # Loop over the required n values
        _, out = fori_loop(0, (N_max + 1).astype(int), body_inner, (alpha, out))
        return out

    out = jnp.zeros(m.size)
    M_max = jnp.max(m)
    # Loop over the different m values
    out = fori_loop(0, (M_max + 1).astype(int), body, out)
    return out
To be fair, only the update function is really relevant to this issue, but I wanted to include some of the surrounding code to show the architecture of my overall function. Since GPUs are not good at conditional branching, I changed my code to look like this:
@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)")
def fun_vectorized(r, m, n, dr):
    def body_inner(N, args):
        alpha, out = args
        result = find_some_array()
        # Find which indices to update
        mask = jnp.logical_and(m == alpha, n == N)
        out = jnp.where(mask, result, out)
        return (alpha, out)

    def body(alpha, out):
        N_max = find_Nmax()
        # Loop over the required n values
        _, out = fori_loop(0, (N_max + 1).astype(int), body_inner, (alpha, out))
        return out

    out = jnp.zeros(m.size)
    M_max = jnp.max(m)
    # Loop over the different m values
    out = fori_loop(0, (M_max + 1).astype(int), body, out)
    return out
I noticed that the update function with a fori_loop is 4-5 times faster than the jnp.where counterpart on CPU. I would have expected them to perform the same, but that is not the case. I can understand the slowdown caused by jax.lax.cond on GPU, but I was wondering if there is a better way of implementing what I am trying to do. Also, why is jnp.where with 3 arguments slower than setting new values on an array? My guess is that it creates a copy of the array instead of updating it in place, but I am not sure.
Some other requirements for my application:
I need to use jnp.where or the update function because there can be 0 or multiple indices which satisfy the condition (the strictly-one-index case is much easier to handle using index = jnp.sum(jnp.where(jnp.logical_and(m == alpha, n == N), jnp.arange(m.size), 0)); see the sketch below).
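To make the distinction concrete, here is a minimal sketch (with made-up toy arrays, not from the original code) contrasting the single-match index trick with a mask-based update that handles zero or multiple matches:

import jax.numpy as jnp

# Hypothetical toy data, just to illustrate the two cases
m = jnp.array([0, 1, 1, 2])
n = jnp.array([0, 0, 0, 1])
out = jnp.zeros(m.size)
result = 7.0

# Strictly one match (alpha=2, N=1): the matching index can be recovered directly
mask_one = jnp.logical_and(m == 2, n == 1)
index = jnp.sum(jnp.where(mask_one, jnp.arange(m.size), 0))  # -> 3
out_single = out.at[index].set(result)                       # [0., 0., 0., 7.]

# Zero or multiple matches (alpha=1, N=0): the sum trick would add the matching
# indices together (1 + 2 = 3), so a mask-based update is needed instead
mask_many = jnp.logical_and(m == 1, n == 0)
out_multi = jnp.where(mask_many, result, out)                 # [0., 7., 7., 0.]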
I couldn't find any existing issue on this; I would be happy if anyone could help me with it.
Thanks for the question! I think the issue you're running into is related to the execution model of loops (fori_loop, while_loop, and scan) on GPU. For GPU backends, each iteration effectively requires a kernel launch, so if you have very cheap iterations it can lead to a lot of overhead. On CPU, there is no such overhead.
On GPU I'd suggest doing this kind of conditional update using lax.select or jnp.where, which is basically a more flexible wrapper around lax.select.
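As a rough illustration of that suggestion (a minimal sketch with a hypothetical mask and arrays, not taken from the issue): the per-element cond inside the loop can be replaced by a single masked select over the whole array, avoiding the per-iteration launches described above.

import jax
import jax.numpy as jnp

mask = jnp.array([False, True, False, True])  # hypothetical condition
out = jnp.zeros(4)
result = 5.0

# jnp.where broadcasts the scalar `result` against `out`
out_where = jnp.where(mask, result, out)                            # [0., 5., 0., 5.]

# lax.select is the lower-level primitive; it expects arrays of matching shape
out_select = jax.lax.select(mask, jnp.full_like(out, result), out)  # [0., 5., 0., 5.]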
Thank you very much for the reply! Yes, on GPU jnp.where works much better than a fori_loop that effectively does nothing for most of its iterations (if I were able to use "continue", maybe that would work, but trueFun and falseFun have to return the same type of thing).
A related question to your reply: does every fori_loop step use a different core of the same GPU? So, basically, if each loop iteration is cheap, is it more efficient to use jnp.vectorize?
Until now, I thought the problem was due to jax.lax.cond(). Thanks for the heads-up. For people seeing a similar slowdown on GPU, here is a very simple comparison:
import jax
import jax.numpy as jnp
from jax.lax import fori_loop


@jax.jit
@jnp.vectorize
def add_vec(x):
    return x + 5


@jax.jit
def loop_for(x):
    def add_for(i, x):
        return x.at[i].add(5)

    x = fori_loop(0, x.size, add_for, x)
    return x


a = jnp.ones(100000)
b = jnp.ones(100000)
c = jnp.ones(100000)

%timeit _ = add_vec(a).block_until_ready()
%timeit _ = loop_for(b).block_until_ready()
%timeit _ = (c + 5).block_until_ready()
On GPU:
45.5 µs ± 257 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
541 ms ± 3.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.8 µs ± 335 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
For reference, the same code on CPU results in:
51.9 µs ± 77.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
172 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
54.5 µs ± 97 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
@jakevdp:
"For GPU backends, each iteration effectively requires a kernel launch"
Is this a limitation of XLA or a fundamental limitation of the hardware itself? (I also asked this here.)
I don't know.