Unexpected segmentation fault on GPU inside flax.linen.scan
Related issue on JAX repo: https://github.com/google/jax/issues/17781
Solution
If updating the CUDA drivers is not possible, setting the flag XLA_FLAGS=--xla_gpu_graph_level=0 works around the segmentation fault.
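If setting the flag in the launching shell is inconvenient, it can also be set from Python before jax is imported; a minimal sketch (note that this overwrites any XLA_FLAGS value already present in the environment):
import os
# Graph level 0 disables XLA's use of CUDA graphs on GPU; must be set before jax is imported.
os.environ["XLA_FLAGS"] = "--xla_gpu_graph_level=0"
import jax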
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04.7 (slurm server)
- Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): jaxlib: 0.4.16+cuda11.cudnn86, jax: 0.4.16, flax: 0.7.4
- Python version: 3.11.5
- GPU/TPU model and memory: Titan RTX 24220MiB
- CUDA version (if applicable): 11.2
Problem you have encountered:
- Segmentation fault on GPU when using flax.linen.scan, apparently when random values generated inside the loop body are used.
- The error is very inconsistent: it happens only on GPU and, in my experience, may or may not occur depending on the hyperparameter values.
- I provide a somewhat minimal example to reproduce. I don't know whether it is actually minimal, since the crash sometimes disappears when I change the hyperparameters.
- In the following code, the segmentation fault arises only when I call the second model, and it also doesn't happen if I comment out the x_2 *= jax.random.uniform(subkey, (batch_size, 1)) < 0.5 line in the scan. It also depends on the hyperparameters used; I couldn't figure out why.
What you expected to happen:
No segmentation fault at all.
Logs, error messages, etc:
Console output + error message:
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695583905.012097 68369 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-09-25 04:31:46.366898: W external/xla/xla/service/gpu/nvptx_compiler.cc:708] The NVIDIA driver's CUDA version is 11.2 which is older than the ptxas CUDA version (11.8.89). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
layer_size: 200; OK
Fatal Python error: Segmentation fault
Current thread 0x00007fdef8423700 (most recent call first):
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1149 in __call__
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/profiler.py", line 314 in wrapper
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 1154 in _pjit_call_impl_python
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 1198 in call_impl_cache_miss
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 1214 in _pjit_call_impl
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/core.py", line 821 in process_primitive
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/core.py", line 389 in bind_with_trace
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/core.py", line 2604 in bind
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 166 in _python_pjit_helper
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 255 in cache_miss
File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 177 in reraise_with_filtered_traceback
File "/gallery_tate/julian/SD/proj/segfault_debug_script.py", line 89 in <module>
Extension modules: jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, msgpack._cmsgpack, yaml._yaml (total: 16)
zsh: segmentation fault (core dumped) PYTHONFAULTHANDLER=1 python
Steps to reproduce:
Here's a script that, when run, produces a segmentation fault regardless of JAX's memory preallocation setting.
On my machine, the error happens only for the smaller layer_size, which suggests it is not a memory issue.
Also, when I comment out the x_2 *= jax.random.uniform(subkey, (batch_size, 1)) < 0.5 line, the segfault doesn't happen either.
import flax.linen as nn
import jax
import jax.numpy as jnp
class Model(nn.Module):
    decoder_1: nn.Module

    @nn.compact
    def __call__(self, key, x_1, x_2):
        def _process_slice(model, carry, inputs):
            x = inputs
            x_2, key = carry
            batch_size = x.shape[0]
            key, subkey = jax.random.split(key)
            x_2 *= jax.random.uniform(subkey, (batch_size, 1)) < 0.5  # If commented, no segfault
            x = jnp.concatenate([x, x_2], axis=-1)
            y_1 = model.decoder_1(x)
            y = y_1
            carry = x_2, key
            return carry, y

        # Call the function once to initialize the model parameters
        carry, y_0 = _process_slice(self, (x_2, key), x_1[0])

        # Scan over the remaining timesteps
        scan = nn.scan(
            _process_slice,
            variable_broadcast="params",
            out_axes=0,
        )
        _, y = scan(self, carry, x_1[1:])
        y = jnp.concatenate([y_0[None], y], axis=0)
        return y


class Decoder(nn.Module):
    hidden_size: int
    num_classes: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_classes)(x)
        return x
# %% Test 1st
layer_size = 200 # 16 = segfault, 200 = no segfault
num_layers = 4
num_classes = 10
temperature = 0.1
nb_timesteps = 10
batch_size = 4
dim_input = 6
decoder_1 = Decoder(layer_size, layer_size)
model = Model(decoder_1)
x_1 = jnp.ones((nb_timesteps, batch_size, dim_input))
x_2 = jnp.ones((batch_size, dim_input))
key = jax.random.PRNGKey(0)
key, sub_1, sub_2 = jax.random.split(key, 3)
params = model.init(sub_1, sub_2, x_1, x_2)
apply_fn = jax.jit(model.apply)
key, sub = jax.random.split(key)
y = apply_fn(params, sub, x_1, x_2)
print(f"layer_size: {layer_size}; OK")
# %% Test 2nd layer size
layer_size = 16 # 16 = segfault, 200 = no segfault
decoder_1 = Decoder(layer_size, layer_size)
model = Model(decoder_1)
key = jax.random.PRNGKey(0)
key, sub_1, sub_2 = jax.random.split(key, 3)
params = model.init(sub_1, sub_2, x_1, x_2)
apply_fn = jax.jit(model.apply)
key, sub = jax.random.split(key)
y = apply_fn(params, sub, x_1, x_2)
print(f"layer_size: {layer_size}; OK")
I cannot reproduce this on CPU or TPU. Must be related to CUDA.
I suggest you post this on the JAX repo, as nn.scan is just a wrapper around jax.lax.scan.
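For what it's worth, here is a rough sketch of the same pattern written directly against jax.lax.scan (shapes are made up to mirror the script above; untested), in case it helps to file a Flax-free repro on the JAX side:
import jax
import jax.numpy as jnp

def body(carry, x):
    # Mirror _process_slice: draw a random mask inside the scan body and
    # thread the PRNG key through the carry.
    x_2, key = carry
    key, subkey = jax.random.split(key)
    x_2 = x_2 * (jax.random.uniform(subkey, (x.shape[0], 1)) < 0.5)
    y = jnp.concatenate([x, x_2], axis=-1)
    return (x_2, key), y

xs = jnp.ones((10, 4, 6))  # (timesteps, batch, features)
x_2 = jnp.ones((4, 6))
key = jax.random.PRNGKey(0)
_, ys = jax.jit(lambda carry, xs: jax.lax.scan(body, carry, xs))((x_2, key), xs)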
For reference, it can be worked around by using the flag XLA_FLAGS=--xla_gpu_graph_level=0 in case the CUDA drivers cannot be updated.