Pallas jax.lax.fori_loop over long inputs slows down
Description
Inside Pallas kernels we often want a loop, and to keep compilation fast we typically use a structured loop primitive such as jax.lax.fori_loop (for example, in the attention kernel example here).
As the length of the loop grows, fori_loop slows down execution substantially relative to a Python for-loop. I put together a minimal script to isolate the issue; running it on an A6000, I saw a 2-3x slowdown on the longest loops:
T=256
python for-loop: compile = 177ms, execution ms_per_kernel_call = 0.317ms
jax.lax.fori_loop: compile = 193ms, execution ms_per_kernel_call = 0.318ms
T=2048
python for-loop: compile = 242ms, execution ms_per_kernel_call = 2.247ms
jax.lax.fori_loop: compile = 200ms, execution ms_per_kernel_call = 2.255ms
T=8192
python for-loop: compile = 473ms, execution ms_per_kernel_call = 8.946ms
jax.lax.fori_loop: compile = 194ms, execution ms_per_kernel_call = 9.281ms
T=16384
python for-loop: compile = 776ms, execution ms_per_kernel_call = 18.177ms
jax.lax.fori_loop: compile = 198ms, execution ms_per_kernel_call = 22.288ms
T=32768
python for-loop: compile = 1460ms, execution ms_per_kernel_call = 36.009ms
jax.lax.fori_loop: compile = 200ms, execution ms_per_kernel_call = 58.552ms
T=65536
python for-loop: compile = 2978ms, execution ms_per_kernel_call = 71.313ms
jax.lax.fori_loop: compile = 195ms, execution ms_per_kernel_call = 172.925ms
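For reference, the kernel's inner loop runs t // fwd_blk_j steps, so with the block size of 64 used in the script below, these T values correspond to loop lengths between 4 and 1024 sequential iterations:

# Loop length implied by each T above (fwd_blk_j = 64, as in the script below).
fwd_blk_j = 64
for T in (256, 2048, 8192, 16384, 32768, 65536):
    print(f'T={T}: {T // fwd_blk_j} fori_loop iterations')
# -> 4, 32, 128, 256, 512, and 1024 iterations respectively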
Here is the script that generated these results:
import time

import jax
import jax.experimental.pallas as pl
import jax.numpy as jnp


# ---------- PALLAS ------------
class JaxKernel:
    fwd_blk_i = 128
    fwd_blk_j = 64

    def __init__(self, use_scan):
        self.use_scan = use_scan

    def __call__(self, X):
        t, d = X.shape
        grid = (t // self.fwd_blk_i,)
        Y = pl.pallas_call(
            self.fwd_kernel,
            grid=grid,
            out_shape=jax.ShapeDtypeStruct(X.shape, X.dtype),
        )(X)
        return Y

    def fwd_kernel(self, X_ref, Y_ref):
        i = pl.program_id(0)
        t, d = X_ref.shape
        X_i = pl.load(X_ref, pl.ds(start=i * self.fwd_blk_i, size=self.fwd_blk_i))
        Y_i_acc = jnp.zeros([self.fwd_blk_i, d], dtype=X_i.dtype)

        def body(j, carry):
            B_ij = X_i.sum()
            carry += B_ij + j  # crashes if loop variable not involved
            return carry

        if self.use_scan:
            Y_i = jax.lax.fori_loop(0, t // self.fwd_blk_j, body, Y_i_acc)
        else:
            for j in range(0, t // self.fwd_blk_j):
                Y_i_acc = body(j, Y_i_acc)
            Y_i = Y_i_acc
        pl.store(Y_ref, pl.ds(start=i * self.fwd_blk_i, size=self.fwd_blk_i), Y_i)


# ------ BENCHMARK UTILS -------
def prepare_data(b, t, d, dtype):
    """Creates the data for a forward pass."""
    return jax.random.normal(jax.random.PRNGKey(0), shape=(b, t, d), dtype=dtype)


def heavyweight(f, n, b, t, d, dtype):
    """Given a kernel for a single batch, transforms it to repeat on many batches.

    Returns output to make sure no computation gets compiled away.
    """

    @jax.jit
    def heavy_f():
        def batch_f(X):
            out = f(X).mean()
            return out

        def scanner(_, __):
            X = prepare_data(b, t, d, dtype)
            return None, batch_f(X)

        _, Y = jax.lax.scan(scanner, None, jnp.arange(n))
        return Y

    return heavy_f


# ---- RUN -----
def main():
    # set up pallas att
    pallas_scanless = jax.vmap(JaxKernel(use_scan=False))
    pallas_scanner = jax.vmap(JaxKernel(use_scan=True))

    # confirm output is correct
    X = prepare_data(1, 2048, 16, jnp.float32)
    pallas_scanless_Y = pallas_scanless(X)
    pallas_scanner_Y = pallas_scanner(X)
    assert jnp.allclose(pallas_scanner_Y, pallas_scanless_Y, atol=.001)

    # choose hyperparameters
    N = 16    # number of times to repeat execution
    B = 512   # batch size
    T = 8192  # context size
    D = 64    # feature size
    dtype = jnp.float16
    print(f'{N=} {B=} {T=} {D=}')

    # jit functions
    heavy_scanless = heavyweight(pallas_scanless, N, B, T, D, dtype)
    heavy_scanner = heavyweight(pallas_scanner, N, B, T, D, dtype)

    # compile
    _t = time.time()
    jax.block_until_ready(heavy_scanless())
    scanless_time_to_compile_and_execute = time.time() - _t

    _t = time.time()
    jax.block_until_ready(heavy_scanner())
    scanner_time_to_compile_and_execute = time.time() - _t

    # time the main execution
    _t = time.time()
    jax.block_until_ready(heavy_scanless())
    scanless_time_to_execute = time.time() - _t
    ms_per_kernel_call = 1000 * scanless_time_to_execute / N
    scanless_time_to_compile = (
        scanless_time_to_compile_and_execute - scanless_time_to_execute)
    print(f'{1000*scanless_time_to_compile = :.3f}ms, '
          f'execution {ms_per_kernel_call = :.3f}ms')

    _t = time.time()
    jax.block_until_ready(heavy_scanner())
    scanner_time_to_execute = time.time() - _t
    ms_per_kernel_call = 1000 * scanner_time_to_execute / N
    scanner_time_to_compile = (
        scanner_time_to_compile_and_execute - scanner_time_to_execute)
    print(f'{1000*scanner_time_to_compile = :.3f}ms, '
          f'execution {ms_per_kernel_call = :.3f}ms')


if __name__ == '__main__':
    main()
System info (python version, jaxlib version, accelerator, etc.)
>>> import jax; jax.print_environment_info()
jax: 0.4.25
jaxlib: 0.4.25
numpy: 1.26.4
python: 3.11.4 (main, Dec 7 2023, 15:43:41) [GCC 12.3.0]
jax.devices (2 total, 2 local): [cuda(id=0) cuda(id=1)]
process_count: 1
platform: uname_result(system='Linux', node='jacob-manifestai', release='6.2.0-39-generic', version='#40-Ubuntu SMP PREEMPT_DYNAMIC Tue Nov 14 14:18:00 UTC 2023', machine='x86_64')
$ nvidia-smi
Wed Apr 24 01:15:09 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.07 Driver Version: 535.161.07 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA RTX A6000 On | 00000000:2D:00.0 Off | Off |
| 30% 39C P2 38W / 300W | 269MiB / 49140MiB | 1% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA RTX A6000 On | 00000000:41:00.0 Off | Off |
| 30% 40C P2 30W / 300W | 269MiB / 49140MiB | 1% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 4176762 C ...anifest2-ozEuPuop-py3.11/bin/python 262MiB |
| 1 N/A N/A 4176762 C ...anifest2-ozEuPuop-py3.11/bin/python 262MiB |
+---------------------------------------------------------------------------------------+
I think the best explanation I found online is the following:
To elaborate on this, the reason GPU is so fast for vectorized operations is not that individual floating point operations are particularly fast (they're actually often slower than similar operations on a CPU!), but rather that it can very efficiently run many operations in parallel. For an operation like scan in which each step depends on the output of the previous, the sequence of operations as a whole cannot be parallelized. So you end up not taking advantage of any of the GPU's inherent parallelism, and the result is slow execution. Contrast this with the CPU, where individual floating point operations are relatively fast, but there is not so much in-built parallelism available. Because of this, scan does not incur as much of a performance penalty.
I think this falls into the same bucket: like my example, where the input of each layer depended on the output of the previous one (the carry), this is a purely sequential loop. There is probably a sweet spot for the unroll parameter where compilation time and loop time are both reasonable.
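To illustrate that unroll idea (a sketch only, not something I have benchmarked and not part of the script above): the kernel body can process a fixed chunk of iterations as straight-line Python per fori_loop step, so the traced loop is shorter while most of the work stays unrolled for the compiler. The unroll factor below is a hypothetical knob and is assumed to divide the step count evenly; jax.lax.scan exposes a similar trade-off directly through its unroll= argument.

# Sketch only: a partially unrolled drop-in for JaxKernel.fwd_kernel above (untested).
# `unroll` is a hypothetical tuning knob, assumed to divide t // fwd_blk_j evenly.
def fwd_kernel_unrolled(self, X_ref, Y_ref):
    unroll = 8
    i = pl.program_id(0)
    t, d = X_ref.shape
    X_i = pl.load(X_ref, pl.ds(start=i * self.fwd_blk_i, size=self.fwd_blk_i))
    n_steps = t // self.fwd_blk_j

    def body(c, carry):
        # Each fori_loop step covers `unroll` consecutive iterations of the
        # original loop, emitted as straight-line code.
        for k in range(unroll):
            j = c * unroll + k
            B_ij = X_i.sum()
            carry += B_ij + j
        return carry

    Y_i = jax.lax.fori_loop(0, n_steps // unroll, body,
                            jnp.zeros([self.fwd_blk_i, d], dtype=X_i.dtype))
    pl.store(Y_ref, pl.ds(start=i * self.fwd_blk_i, size=self.fwd_blk_i), Y_i)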