
Conditional array update on GPU using jnp.where vs fori_loop

Open YigitElma opened this issue 11 months ago • 2 comments

Description:

Hi! I am fairly new to using JAX. I have been trying to update one or more entries of a 1D array based on some condition, inside a function wrapped with jax.jit and jnp.vectorize. I found a very fast way of doing this on CPU; however, when I tested it on GPU, it suffered an extreme slowdown (around 1000x).

For more context, the approach that seems to be fastest on CPU is similar to this code:

import functools

import jax
from jax.lax import cond, fori_loop
import jax.numpy as jnp

@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)")
def fun_vectorized(r, m, n, dr):
    """
     m.size == n.size
     m and n may have repeating values
     N_max and M_max values are found using some functions
    """
    def update(i, args):
        """Update the output if required."""
        alpha, N, result, out = args
        idx = jnp.where(jnp.logical_and(m[i] == alpha, n[i] == N), i, -1)

        def falseFun(args):
            """Do nothing."""
            _, _, out = args
            return out

        def trueFun(args):
            """Update the value at idx."""
            idx, result, out = args
            out = out.at[idx].set(result)
            return out

        out = cond(idx >= 0, trueFun, falseFun, (idx, result, out))
        return (alpha, N, result, out)

    def body_inner(N, args):
        alpha, out = args
        result = find_some_array()  # irrelevant for this issue

        # Update array
        _, _, _, out = fori_loop(0, m.size, update, (alpha, N, result, out))
        return (alpha, out)

    def body(alpha, out):
        # Find max value of n corresponding to alpha
        # This requires another function which is not relevant
        N_max = find_Nmax()
        # Loop over required n values
        _, out = fori_loop(
            0, (N_max + 1).astype(int), body_inner, (alpha, out)
        )
        return out

    out = jnp.zeros(m.size)
    M_max = jnp.max(m)
    # Loop over different m values
    out = fori_loop(0, (M_max + 1).astype(int), body, (out))

    return out

To be fair, only the update function is relevant to this issue, but I wanted to include some of the surrounding code to show the architecture of my overall function. Since GPUs are not good at conditional branching, I changed my code to look like this:

@functools.partial(jnp.vectorize, excluded=(1, 2, 3), signature="()->(k)")
def fun_vectorized(r, m, n, dr):
    
    def body_inner(N, args):
        alpha, out = args
        result = find_some_array()

        # Find which indices to update
        mask = jnp.logical_and(m == alpha, n == N)
        out = jnp.where(mask, result, out)
        return (alpha, out)

    def body(alpha, out):
        N_max = find_Nmax()
        # Loop over required n values
        _, out = fori_loop(
            0, (N_max + 1).astype(int), body_inner, (alpha, out)
        )
        return out

    out = jnp.zeros(m.size)
    M_max = jnp.max(m)
    # Loop over different m values
    out = fori_loop(0, (M_max + 1).astype(int), body, (out))

    return out

I noticed that on CPU, the update function with a fori_loop is 4-5 times faster than the jnp.where counterpart. I would have expected them to perform the same, but that is not the case. I can understand the slowdown due to jax.lax.cond on GPU, but I was wondering if there is a better way of implementing what I am trying to do. Also, why is jnp.where with 3 arguments slower than setting new values on an array? My guess is that it creates a copy of the array instead of updating in place, but I am not sure.

Some other requirements for my application:

I need to use jnp.where or the update function because there can be 0 or multiple indices that satisfy the condition (the strictly-one-index case is much easier to handle using index = jnp.sum(jnp.where(jnp.logical_and(m == alpha, n == N), jnp.arange(m.size), 0))).
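To make that requirement concrete, here is a minimal sketch (with made-up m, n, alpha, N, and result values) showing that a masked jnp.where update handles zero or several matching indices in a single operation:

```python
import jax.numpy as jnp

# hypothetical small inputs: two entries match (alpha, N) here,
# but the same code also works when no entry matches
m = jnp.array([0, 1, 1, 2])
n = jnp.array([0, 3, 3, 1])
alpha, N, result = 1, 3, 7.0

out = jnp.zeros(m.size)
mask = jnp.logical_and(m == alpha, n == N)  # [False, True, True, False]
out = jnp.where(mask, result, out)          # all matches updated at once
```
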

I couldn't find any existing issue on this; I would be happy if anyone can help me with it.

YigitElma avatar Feb 26 '24 03:02 YigitElma

Thanks for the question! I think the issue you're running into is related to the execution model of loops (fori_loop, while_loop, and scan) on GPU. For GPU backends, each iteration effectively requires a kernel launch, so if you have very cheap iterations it can lead to a lot of overhead. On CPU, there is no such overhead.

On GPU I'd suggest doing this kind of conditional update using lax.select, or jnp.where, which is basically a more flexible wrapper around lax.select.
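As a minimal sketch of the difference (with made-up operands): lax.select requires the two cases to be arrays of identical shape and dtype, while jnp.where also broadcasts, e.g. a scalar fill value:

```python
import jax.numpy as jnp
from jax import lax

mask = jnp.array([True, False, True])
on_true = jnp.array([1.0, 2.0, 3.0])
on_false = jnp.zeros(3)

a = lax.select(mask, on_true, on_false)  # operands must match in shape and dtype
b = jnp.where(mask, on_true, 0.0)        # jnp.where broadcasts the scalar 0.0
```

Both produce the same result here; jnp.where just accepts a wider range of argument shapes.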

jakevdp avatar Feb 26 '24 20:02 jakevdp

Thank you very much for the reply! Yes, on GPU jnp.where works far better than a fori_loop that does nothing on most iterations (if I could use "continue", maybe that would work, but trueFun and falseFun have to return the same type of values).
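For anyone reading along, a minimal sketch of that branch constraint (with hypothetical idx and result values): both branches of lax.cond receive the same operands and must return the same type and shape, so "do nothing" has to be written as the identity:

```python
import jax.numpy as jnp
from jax import lax

out = jnp.zeros(4)
idx, result = 2, 5.0

# both branches take the same operand tuple and return an array of the
# same shape/dtype, so "continue" becomes an identity branch
out = lax.cond(
    idx >= 0,
    lambda args: args[2].at[args[0]].set(args[1]),  # update branch
    lambda args: args[2],                           # identity branch
    (idx, result, out),
)
```
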

A related question to your reply: does every fori_loop step use a different core of the same GPU? So, basically, if each loop iteration is cheap, is it more efficient to use jnp.vectorize?

Until now, I thought the problem was due to jax.lax.cond(). Thanks for the heads-up. For anyone seeing a similar slowdown on GPU, here is a very simple comparison:

import jax
import jax.numpy as jnp
from jax.lax import fori_loop

@jax.jit
@jnp.vectorize
def add_vec(x):
    return x + 5

@jax.jit
def loop_for(x):

    def add_for(i, x):
        return x.at[i].add(5)

    x = fori_loop(0, x.size, add_for, x)
    return x

a = jnp.ones(100000)
b = jnp.ones(100000)
c = jnp.ones(100000)

%timeit _ = add_vec(a).block_until_ready()
%timeit _ = loop_for(b).block_until_ready()
%timeit _ = (c+5).block_until_ready()
45.5 µs ± 257 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
541 ms ± 3.14 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
83.8 µs ± 335 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

For reference, the same code on CPU results in:

51.9 µs ± 77.5 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
172 µs ± 151 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
54.5 µs ± 97 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

YigitElma avatar Feb 27 '24 03:02 YigitElma

@jakevdp

For GPU backends, each iteration effectively requires a kernel launch

Is this a limitation of XLA or a fundamental limitation of the hardware itself? (I also asked this here.)

carlosgmartin avatar Jul 11 '24 16:07 carlosgmartin

I don't know

jakevdp avatar Jul 11 '24 16:07 jakevdp