
Unexpected segmentation fault on GPU inside flax.linen.scan

Open Kajiih opened this issue 2 years ago • 2 comments

Related issue on JAX repo: https://github.com/google/jax/issues/17781

Solution

If updating the CUDA drivers is not possible, setting the flag XLA_FLAGS=--xla_gpu_graph_level=0 resolves the segmentation fault.
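For example, the flag can be passed on the command line for a single run, or set from Python before JAX is used (placing it before the import is an assumption, to make sure XLA reads the variable before the GPU backend is initialized; the script name is the one from the traceback below):

# From the shell, for a single run:
XLA_FLAGS=--xla_gpu_graph_level=0 python segfault_debug_script.py

# Or from Python, before importing jax:
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_graph_level=0"
import jax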

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 16.04.7 (slurm server)
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): jaxlib: 0.4.16+cuda11.cudnn86, jax: 0.4.16, flax: 0.7.4
  • Python version: 3.11.5
  • GPU/TPU model and memory: Titan RTX 24220MiB
  • CUDA version (if applicable): 11.2

Problem you have encountered:

  • Segmentation fault on GPU when using flax.linen.scan, apparently triggered when random values generated inside the loop body are used.
  • The error is very inconsistent: it happens only on GPU and, in my experience, whether it occurs depends on the hyperparameter values.
  • I provide a somewhat minimal example to reproduce; I don't know whether it is actually minimal, since the fault sometimes disappears when I change the hyperparameters.
  • In the code below, the segmentation fault arises only when I call the second model, and it also doesn't happen if I comment out the x_2 *= jax.random.uniform(subkey, (batch_size, 1)) < 0.5 line in the scan. It also depends on the hyperparameters used; I couldn't figure out why.

What you expected to happen:

No segmentation fault at all.

Logs, error messages, etc:

Console output + error message:

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1695583905.012097   68369 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-09-25 04:31:46.366898: W external/xla/xla/service/gpu/nvptx_compiler.cc:708] The NVIDIA driver's CUDA version is 11.2 which is older than the ptxas CUDA version (11.8.89). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
layer_size: 200; OK
Fatal Python error: Segmentation fault

Current thread 0x00007fdef8423700 (most recent call first):
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1149 in __call__
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/profiler.py", line 314 in wrapper
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 1154 in _pjit_call_impl_python
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 1198 in call_impl_cache_miss
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 1214 in _pjit_call_impl
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/core.py", line 821 in process_primitive
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/core.py", line 389 in bind_with_trace
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/core.py", line 2604 in bind
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 166 in _python_pjit_helper
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/pjit.py", line 255 in cache_miss
  File "/gallery_tate/julian/.miniconda3/envs/jax_test/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 177 in reraise_with_filtered_traceback
  File "/gallery_tate/julian/SD/proj/segfault_debug_script.py", line 89 in <module>

Extension modules: jaxlib.cpu_feature_guard, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, msgpack._cmsgpack, yaml._yaml (total: 16)
zsh: segmentation fault (core dumped) PYTHONFAULTHANDLER=1 python 

Steps to reproduce:

Here's a script that, when run, produces a segmentation fault regardless of JAX's memory preallocation settings. On my machine, the error happens only for the smaller layer_size, which suggests it is not a memory issue. Also, when I comment out the x_2 *= jax.random.uniform(subkey, (batch_size, 1)) < 0.5 line, the segfault doesn't happen either.

import flax.linen as nn
import jax
import jax.numpy as jnp


class Model(nn.Module):
    decoder_1: nn.Module

    @nn.compact
    def __call__(self, key, x_1, x_2):
        def _process_slice(model, carry, inputs):
            x = inputs
            x_2, key = carry

            batch_size = x.shape[0]
            key, subkey = jax.random.split(key)
            x_2 *= jax.random.uniform(subkey, (batch_size, 1)) < 0.5  # If commented, no segfault

            x = jnp.concatenate([x, x_2], axis=-1)
            y_1 = model.decoder_1(x)
            y = y_1
            carry = x_2, key
            return carry, y

        # Call the function once to initialize the model parameters
        carry, y_0 = _process_slice(self, (x_2, key), x_1[0])

        # Scan over the remaining timesteps
        scan = nn.scan(
            _process_slice,
            variable_broadcast="params",
            out_axes=0,
        )
        _, y = scan(self, carry, x_1[1:])
        y = jnp.concatenate([y_0[None], y], axis=0)
        return y


class Decoder(nn.Module):
    hidden_size: int
    num_classes: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.relu(x)
        x = nn.Dense(self.num_classes)(x)
        return x


# %% Test 1st layer size
layer_size = 200 # 16 = segfault, 200 = no segfault
num_layers = 4
num_classes = 10
temperature = 0.1
nb_timesteps = 10
batch_size = 4
dim_input = 6

decoder_1 = Decoder(layer_size, layer_size)
model = Model(decoder_1)

x_1 = jnp.ones((nb_timesteps, batch_size, dim_input))
x_2 = jnp.ones((batch_size, dim_input))

key = jax.random.PRNGKey(0)
key, sub_1, sub_2 = jax.random.split(key, 3)
params = model.init(sub_1, sub_2, x_1, x_2)
apply_fn = jax.jit(model.apply)

key, sub = jax.random.split(key)
y = apply_fn(params, sub, x_1, x_2)

print(f"layer_size: {layer_size}; OK")

# %% Test 2nd layer size
layer_size = 16 # 16 = segfault, 200 = no segfault

decoder_1 = Decoder(layer_size, layer_size)

model = Model(decoder_1)

key = jax.random.PRNGKey(0)
key, sub_1, sub_2 = jax.random.split(key, 3)
params = model.init(sub_1, sub_2, x_1, x_2)
apply_fn = jax.jit(model.apply)

key, sub = jax.random.split(key)
y = apply_fn(params, sub, x_1, x_2)

print(f"layer_size: {layer_size}; OK")

Kajiih · Sep 24 '23 19:09

I cannot reproduce this on CPU or TPU, so it must be related to CUDA. I suggest you post this on the JAX repo, as nn.scan is just a wrapper around jax.lax.scan.
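For illustration, here is a minimal sketch of the same loop body written directly with jax.lax.scan; the dense-layer parameters are hypothetical stand-ins for decoder_1, shaped to mirror the reproduction script. Running something like this could help check whether the crash reproduces without Flax:

import jax
import jax.numpy as jnp

# Hypothetical sizes, mirroring the reproduction script above.
nb_timesteps, batch_size, dim_input, layer_size = 10, 4, 6, 16

key = jax.random.PRNGKey(0)
key, w_key = jax.random.split(key)
# Stand-in parameters for decoder_1: a single dense layer.
params = {
    "w": jax.random.normal(w_key, (2 * dim_input, layer_size)),
    "b": jnp.zeros((layer_size,)),
}

x_1 = jnp.ones((nb_timesteps, batch_size, dim_input))
x_2 = jnp.ones((batch_size, dim_input))


def step(carry, x):
    x_2, key = carry
    key, subkey = jax.random.split(key)
    # Same suspect pattern: random values generated and used inside the loop body.
    x_2 = x_2 * (jax.random.uniform(subkey, (x.shape[0], 1)) < 0.5)
    y = jnp.concatenate([x, x_2], axis=-1) @ params["w"] + params["b"]
    return (x_2, key), y


_, y = jax.jit(lambda carry, xs: jax.lax.scan(step, carry, xs))((x_2, key), x_1)
print(y.shape)  # (nb_timesteps, batch_size, layer_size)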

cgarciae · Sep 25 '23 21:09

For reference, this can be worked around with the flag XLA_FLAGS=--xla_gpu_graph_level=0 in case the CUDA drivers cannot be updated.

Kajiih · Sep 29 '23 11:09