
segfault with `map` and `pure_callback`

Michael-T-McCann opened this issue 5 months ago · 1 comment

### Description

I get a segfault when running the following code snippet on my GPU machine.

```python
import numpy as np
import jax
import jax.numpy as jnp


def f_host(x):
    # call a numpy (not jax.numpy) operation:
    return np.sin(x).astype(x.dtype)


def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(f_host, result_shape, x)


def f_scaled(x):
    return 10 * f(x)


x = jnp.zeros((8, 256))  # Dummy input data
jax.lax.map(f_scaled, x)  # fails
```

The error:

```
F0923 14:25:19.859902   47282 shape_tree.cc:54] Check failed: result->children_start_id >= 0 (-1 vs. 0) 
*** Check failure stack trace: ***
    @     0x7f23f9c701f4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7f23f9c700f4  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7f23f9c70599  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f23f6df4f11  xla::internal::IndexTable::operator[]()
    @     0x7f23f6b8eff4  xla::HloDataflowAnalysis::GetValueSet()
    @     0x7f23f68f1eec  xla::BufferAssignment::GetUniqueSlice()
    @     0x7f23f64a346e  xla::gpu::GetAllocationSlice()
    @     0x7f23f59a3355  xla::gpu::(anonymous namespace)::GetResultSlice()
    @     0x7f23f59a5c27  xla::ShapeUtil::ForEachSubshapeWithStatus<>()::{lambda()#1}::operator()()
    @     0x7f23f59a5a98  xla::ShapeUtil::ForEachMutableSubshapeWithStatusHelper<>()
    @     0x7f23f59a5b1f  xla::ShapeUtil::ForEachMutableSubshapeWithStatusHelper<>()
    @     0x7f23f599f274  xla::gpu::DynamicSliceFusion::Emit()
    @     0x7f23f58dce84  xla::gpu::IrEmitterUnnested::EmitFusion()
    @     0x7f23f58e77f5  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7f23f58c7f6e  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7f23f58defa3  xla::gpu::IrEmitterUnnested::BuildWhileThunk()
    @     0x7f23f58ded36  xla::gpu::IrEmitterUnnested::EmitWhile()
    @     0x7f23f58e7776  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7f23f58c7f6e  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7f23f55e6eee  xla::gpu::CompileModuleToLlvmIr()
    @     0x7f23f55b4ff6  xla::gpu::GpuCompiler::CompileToBackendResult()
    @     0x7f23f55b7cce  xla::gpu::GpuCompiler::RunBackend()
    @     0x7f23f5389c92  xla::Service::BuildExecutable()
    @     0x7f23f5353415  xla::LocalService::CompileExecutables()
    @     0x7f23f5346ae4  xla::LocalClient::Compile()
    @     0x7f23f52ed9ea  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f23f528d821  xla::StreamExecutorGpuClient::Compile()
    @     0x7f23f52eeb6f  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f23f524902a  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @     0x7f23f5238d34  pjrt::PJRT_Client_Compile()
    @     0x7f2407bbe65d  xla::InitializeArgsAndCompile()
    @     0x7f2407bbeb3e  xla::PjRtCApiClient::Compile()
    @     0x7f240d49e6fc  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7f240d499a01  xla::ifrt::PjRtCompiler::Compile()
    @     0x7f240cc6d25e  xla::PyClient::CompileIfrtProgram()
    @     0x7f240cc6defb  xla::PyClient::Compile()
    @     0x7f240cc75f1f  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f240d476aed  nanobind::detail::nb_func_vectorcall_complex()
    @     0x7f26bde6d548  nanobind::detail::nb_bound_method_vectorcall()
    @     0x556aed3cc7dc  _PyEval_EvalFrameDefault
Aborted (core dumped)
```

I noticed that the segfault does not occur if the dummy input data is smaller (say, `(8, 10)`). I wasn't able to provoke the error using `jit` or `pmap`. The error also goes away if I set `jax.config.update("jax_disable_jit", True)`.
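For anyone hitting the same crash, the observations above suggest two ways to route around it until the compiler bug is fixed. This is a sketch based only on the behavior reported here (same shapes and scale factor as the repro; `jax.disable_jit()` is the context-manager form of the config flag mentioned above), not a fix for the underlying XLA issue:

```python
import numpy as np
import jax
import jax.numpy as jnp


def f_host(x):
    # Host-side NumPy computation, as in the repro above.
    return np.sin(x).astype(x.dtype)


def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(f_host, result_shape, x)


def f_scaled(x):
    return 10 * f(x)


x = jnp.zeros((8, 256))

# Mitigation 1: apply the function under jit over the whole batch instead
# of jax.lax.map; per the report, jit (and pmap) did not trigger the crash.
y1 = jax.jit(f_scaled)(x)

# Mitigation 2: run jax.lax.map eagerly with jit disabled; per the report,
# the error goes away in this mode (at the cost of uncompiled execution).
with jax.disable_jit():
    y2 = jax.lax.map(f_scaled, x)
```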

### System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.0
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='XXX', release='5.15.0-122-generic', version='#132~20.04.1-Ubuntu SMP Fri Aug 30 15:50:07 UTC 2024', machine='x86_64')
```

```
$ nvidia-smi
Mon Sep 23 14:23:24 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:1A:00.0 Off |                  N/A |
| 30%   30C    P0             53W /  250W |     160MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:1B:00.0 Off |                  N/A |
| 30%   34C    P0             53W /  250W |     160MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:1D:00.0 Off |                  N/A |
| 30%   36C    P0             53W /  250W |     160MiB /  11264MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:1E:00.0 Off |                  N/A |
| 30%   31C    P0             62W /  250W |     160MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:3D:00.0 Off |                  N/A |
| 30%   29C    P0             66W /  250W |     160MiB /  11264MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:3E:00.0 Off |                  N/A |
| 30%   30C    P0             53W /  250W |     160MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:3F:00.0 Off |                  N/A |
| 30%   29C    P0             51W /  250W |     160MiB /  11264MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA GeForce RTX 2080 Ti     Off |   00000000:41:00.0 Off |                  N/A |
| 30%   32C    P0             51W /  250W |     160MiB /  11264MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
|    1   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
|    2   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
|    3   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
|    4   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
|    5   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
|    6   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
|    7   N/A  N/A     47282      C   ...iniconda3/envs/scico/bin/python3.10        156MiB |
+-----------------------------------------------------------------------------------------+
```

Michael-T-McCann · Sep 23 '24 20:09