Pallas kernel crash at `llo::CouldLtS32` when `interpret=False` on TPU
Description
Hello, I'm running into a core dump when writing TPU kernels. The kernel works with `interpret=True`, but with `interpret=False` it core dumps. Any temporary fix is appreciated!
import functools

import jax
import jax.numpy as jnp
from jax import random
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
from jax.nn import gelu

def matmul_bias_gelu_kernel(x_ref, y_ref, b_ref, z_ref, acc_ref, *, nsteps):
    # Zero the accumulator on the first step of the contraction (k) axis.
    @pl.when(pl.program_id(axis=4) == 0)
    def _():
        acc_ref[...] = jnp.zeros_like(acc_ref)
    print(f"{acc_ref[...].shape = }", f"{x_ref[...].shape = }", f"{y_ref[...].shape = }", f"{b_ref[...].shape = }")
    acc_ref[...] += jnp.matmul(
        x_ref[...].squeeze(axis=(0, 1)), y_ref[...].squeeze(axis=(0,)), preferred_element_type=jnp.float32
    ) + b_ref[...]
    # On the last k step, apply gelu and write the accumulator out.
    @pl.when(pl.program_id(axis=4) == nsteps - 1)
    def _():
        z_ref[...] = gelu(acc_ref[...].astype(z_ref.dtype))

@functools.partial(jax.jit, static_argnames=['bm', 'bk', 'bn', 'bl', 'bnc'])
def matmul_bias_gelu(
    x: jax.Array,
    y: jax.Array,
    b: jax.Array,
    *,
    bl: int = 1,
    bm: int = 16,
    bk: int = 64,
    bn: int = 128,
    bnc: int = 1,
):
    """Compute gelu(x @ y + b)."""
    nc, l, m, k = x.shape
    l2, k2, n = y.shape
    one, l3, _one, n2 = b.shape
    assert l == l2 == l3 and k == k2 and one == 1 and _one == 1 and n == n2, f'Invalid dims {x.shape=} {y.shape=} {b.shape=}'
    assert l % bl == 0 and m % bm == 0 and k % bk == 0 and nc % bnc == 0, 'Dims must be multiples of the block sizes'
    return pl.pallas_call(
        functools.partial(matmul_bias_gelu_kernel, nsteps=k // bk),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                pl.BlockSpec((bnc, bl, bm, bk), lambda nc, l, i, j, k: (nc, l, i, k)),
                pl.BlockSpec((bl, bk, bn), lambda nc, l, i, j, k: (l, k, j)),
                pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (1, l, 1, j)),
            ],
            out_specs=pl.BlockSpec((bnc, bl, bm, bn), lambda nc, l, i, j, k: (nc, l, i, j)),
            scratch_shapes=[pltpu.VMEM((bnc, bl, bm, bn), jnp.float32)],
            grid=(nc // bnc, l // bl, m // bm, n // bn, k // bk),
        ),
        out_shape=jax.ShapeDtypeStruct((nc, l, m, n), x.dtype),
        compiler_params=dict(mosaic=dict(
            dimension_semantics=("parallel", "parallel", "parallel", "parallel", "arbitrary"))),
        # interpret=True,  # passes with interpret=True; crashes with interpret=False
    )(x, y, b)
I ran it with:
key = random.PRNGKey(42)
BS, NH, L, HF = 16, 32, 16 * 128, 64
CS = 16
NC = L // CS
HF_prime = 4 * HF
XB = random.normal(key, (NC, BS*NH, CS, HF))
W1 = random.normal(key, (BS*NH, HF, HF_prime))
b1 = random.normal(key, (1, BS*NH, 1, HF_prime))
output_ref = matmul_bias_gelu_ref(XB, W1, b1)
output_kernel = matmul_bias_gelu(XB, W1, b1)
print('max abs error', jnp.max(jnp.abs(output_ref - output_kernel)))
print(f"{output_ref.shape = } {output_kernel.shape = }")
assert jnp.allclose(output_ref, output_kernel, atol=1e-2)
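(`matmul_bias_gelu_ref` isn't shown above; a minimal sketch of an equivalent reference, assumed from the kernel's docstring, would be:)

# Assumed stand-in for the reference implementation (not included in the
# issue); per the docstring it computes gelu(x @ y + b) directly.
def matmul_bias_gelu_ref(x, y, b):
    # x: (nc, l, m, k) @ y: (l, k, n) -> (nc, l, m, n), then add b: (1, l, 1, n)
    return gelu(jnp.einsum('clmk,lkn->clmn', x, y) + b)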
With `interpret=True` the assertion passes. With `interpret=False`, I get a core dump:
F0821 21:26:33.618965 402385 math_util.cc:68] Check failed: llo::CouldLtS32(digits[i], bounds[i])
*** Check failure stack trace: ***
@ 0x7f021a32ebe4 (unknown)
@ 0x7f021a32e6e8 (unknown)
@ 0x7f021a32f009 (unknown)
@ 0x7f021301eb4d (unknown)
@ 0x7f02130122ec (unknown)
@ 0x7f021301287b (unknown)
@ 0x7f021300ad7e (unknown)
@ 0x7f02130017f1 (unknown)
@ 0x7f0212ffe478 (unknown)
@ 0x7f02109c1908 (unknown)
@ 0x7f02109bd218 (unknown)
@ 0x7f02109b1a53 (unknown)
@ 0x7f0210999578 (unknown)
@ 0x7f02109b1dd9 (unknown)
@ 0x7f02109b61ce (unknown)
@ 0x7f02109b9307 (unknown)
@ 0x7f0219f3499b (unknown)
@ 0x7f0219f3b224 (unknown)
@ 0x7f0219f44045 (unknown)
@ 0x7f021a1fce53 (unknown)
@ 0x7f02cb494ac3 (unknown)
https://symbolize.stripped_domain/r/?trace=7f021a32ebe4,7f021a32e6e7,7f021a32f008,7f021301eb4c,7f02130122eb,7f021301287a,7f021300ad7d,7f02130017f0,7f0212ffe477,7f02109c1907,7f02109bd217,7f02109b1a52,7f0210999577,7f02109b1dd8,7f02109b61cd,7f02109b9306,7f0219f3499a,7f0219f3b223,7f0219f44044,7f021a1fce52,7f02cb494ac2&map=
https://symbolize.stripped_domain/r/?trace=7f02cb4969fc,7f02cb44251f&map=
*** SIGABRT received by PID 401472 (TID 402385) on cpu 12 from PID 401472; ***
E0821 21:26:33.654562 402385 coredump_hook.cc:316] RAW: Remote crash data gathering hook invoked.
E0821 21:26:33.654581 402385 coredump_hook.cc:355] RAW: Skipping coredump since rlimit was 0 at process start.
E0821 21:26:33.654589 402385 client.cc:269] RAW: Coroner client retries enabled, will retry for up to 30 sec.
E0821 21:26:33.654595 402385 coredump_hook.cc:411] RAW: Sending fingerprint to remote end.
E0821 21:26:33.654622 402385 coredump_hook.cc:420] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] stat failed on crash reporting socket /var/google/services/logmanagerd/remote_coredump.socket (Is the listener running?): No such file or directory
E0821 21:26:33.654630 402385 coredump_hook.cc:472] RAW: Dumping core locally.
F0821 21:26:33.618965 402385 math_util.cc:68] Check failed: llo::CouldLtS32(digits[i], bounds[i])
E0821 21:26:33.900571 402385 process_state.cc:805] RAW: Raising signal 6 with default behavior
Aborted (core dumped)
System info (python version, jaxlib version, accelerator, etc.)
>>> import jax
>>> jax.print_environment_info()
jax: 0.4.31
jaxlib: 0.4.31
numpy: 2.1.0
python: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-187c2405-w-0', release='5.19.0-1022-gcp', version='#24~22.04.1-Ubuntu SMP Sun Apr 23 09:51:08 UTC 2023', machine='x86_64')
Pinging @sharadmv who will know best.
I think there is an out-of-bounds bug in the kernel you wrote, and it's tripping a runtime bounds check.
Specifically, the block spec for b:
pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (1, l, 1, j)),
should be:
pl.BlockSpec((1, bl, 1, bn), lambda nc, l, i, j, k: (0, l, 0, j)),
The index map returns block indices, not element offsets, and b has only a single block along its first and third axes, so the only valid index there is 0; returning 1 reads one block past the end.
I think we could catch this error in interpret mode if we used checkify to look for OOB indexing.
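Concretely, with the repro's default block sizes (a quick sketch of the block-index arithmetic; the variable names here are just for illustration):

# b1 has shape (1, BS*NH, 1, HF_prime) = (1, 512, 1, 256) and block shape
# (1, bl, 1, bn) = (1, 1, 1, 128), so the grid of blocks over b is:
b_shape, block_shape = (1, 512, 1, 256), (1, 1, 1, 128)
num_blocks = tuple(d // bd for d, bd in zip(b_shape, block_shape))
print(num_blocks)  # (1, 512, 1, 2)
# Axes 0 and 2 each contain exactly one block, so the only valid block index
# there is 0; the index map's (1, l, 1, j) asks for block 1 on both axes.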
If you run this kernel with the new TPU Interpret Mode (by passing interpret=pltpu.InterpretParams() to pl.pallas_call), it will flag the out-of-bounds block index:
IndexError: Out-of-bounds block index (1, 0, 1, 0) for input "b_ref"
in iteration (0, 0, 0, 0, 0) on device 0 (core 0):
reading [(slice(1, 2, 1), slice(0, 1, 1), slice(1, 2, 1), slice(0, 128, 1))]
but input has shape (1, 512, 1, 256)
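For reference, a minimal standalone sketch of the same class of bug under the new interpret mode (the toy kernel and shapes here are made up; only the interpret=pltpu.InterpretParams() argument is the API described above):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def copy_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...]

x = jnp.zeros((1, 128), jnp.float32)
pl.pallas_call(
    copy_kernel,
    grid=(1,),
    # Block index 1 on an axis with a single block -- the same mistake as the
    # b spec above; it should be lambda i: (0, 0).
    in_specs=[pl.BlockSpec((1, 128), lambda i: (1, 0))],
    out_specs=pl.BlockSpec((1, 128), lambda i: (0, 0)),
    out_shape=jax.ShapeDtypeStruct((1, 128), jnp.float32),
    interpret=pltpu.InterpretParams(),  # should raise IndexError for the OOB block
)(x)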