segfault with `map` and `pure_callback`
Description
I get a segfault when running the following code snippet on my GPU machine.
```python
import numpy as np
import jax
import jax.numpy as jnp


def f_host(x):
    # call a numpy (not jax.numpy) operation:
    return np.sin(x).astype(x.dtype)


def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(f_host, result_shape, x)


def f_scaled(x):
    return 10 * f(x)


x = jnp.zeros((8, 256))  # Dummy input data
jax.lax.map(f_scaled, x)  # fails
```
The error:
```
F0923 14:25:19.859902 47282 shape_tree.cc:54] Check failed: result->children_start_id >= 0 (-1 vs. 0)
*** Check failure stack trace: ***
    @     0x7f23f9c701f4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7f23f9c700f4  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7f23f9c70599  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f23f6df4f11  xla::internal::IndexTable::operator[]()
    @     0x7f23f6b8eff4  xla::HloDataflowAnalysis::GetValueSet()
    @     0x7f23f68f1eec  xla::BufferAssignment::GetUniqueSlice()
    @     0x7f23f64a346e  xla::gpu::GetAllocationSlice()
    @     0x7f23f59a3355  xla::gpu::(anonymous namespace)::GetResultSlice()
    @     0x7f23f59a5c27  xla::ShapeUtil::ForEachSubshapeWithStatus<>()::{lambda()#1}::operator()()
    @     0x7f23f59a5a98  xla::ShapeUtil::ForEachMutableSubshapeWithStatusHelper<>()
    @     0x7f23f59a5b1f  xla::ShapeUtil::ForEachMutableSubshapeWithStatusHelper<>()
    @     0x7f23f599f274  xla::gpu::DynamicSliceFusion::Emit()
    @     0x7f23f58dce84  xla::gpu::IrEmitterUnnested::EmitFusion()
    @     0x7f23f58e77f5  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7f23f58c7f6e  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7f23f58defa3  xla::gpu::IrEmitterUnnested::BuildWhileThunk()
    @     0x7f23f58ded36  xla::gpu::IrEmitterUnnested::EmitWhile()
    @     0x7f23f58e7776  xla::gpu::IrEmitterUnnested::EmitHloInstruction()
    @     0x7f23f58c7f6e  xla::gpu::IrEmitterUnnested::EmitHloComputation()
    @     0x7f23f55e6eee  xla::gpu::CompileModuleToLlvmIr()
    @     0x7f23f55b4ff6  xla::gpu::GpuCompiler::CompileToBackendResult()
    @     0x7f23f55b7cce  xla::gpu::GpuCompiler::RunBackend()
    @     0x7f23f5389c92  xla::Service::BuildExecutable()
    @     0x7f23f5353415  xla::LocalService::CompileExecutables()
    @     0x7f23f5346ae4  xla::LocalClient::Compile()
    @     0x7f23f52ed9ea  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f23f528d821  xla::StreamExecutorGpuClient::Compile()
    @     0x7f23f52eeb6f  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f23f524902a  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @     0x7f23f5238d34  pjrt::PJRT_Client_Compile()
    @     0x7f2407bbe65d  xla::InitializeArgsAndCompile()
    @     0x7f2407bbeb3e  xla::PjRtCApiClient::Compile()
    @     0x7f240d49e6fc  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7f240d499a01  xla::ifrt::PjRtCompiler::Compile()
    @     0x7f240cc6d25e  xla::PyClient::CompileIfrtProgram()
    @     0x7f240cc6defb  xla::PyClient::Compile()
    @     0x7f240cc75f1f  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7f240d476aed  nanobind::detail::nb_func_vectorcall_complex()
    @     0x7f26bde6d548  nanobind::detail::nb_bound_method_vectorcall()
    @     0x556aed3cc7dc  _PyEval_EvalFrameDefault
Aborted (core dumped)
```
I noticed that this does not occur if the dummy input data is smaller (say, `(8, 10)`). I wasn't able to provoke the error using `jit` or `pmap`. The error does go away if I set `jax.config.update("jax_disable_jit", True)`.
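In case it helps triage, here is a sketch of the two non-crashing paths described above (`jit` over the whole batch, and `lax.map` with jit disabled); both ran without issue on my machine:

```python
import numpy as np
import jax
import jax.numpy as jnp


def f_host(x):
    return np.sin(x).astype(x.dtype)


def f(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(f_host, result_shape, x)


def f_scaled(x):
    return 10 * f(x)


x = jnp.zeros((8, 256))

# Path 1: jit over the full batch instead of lax.map -- no crash.
y1 = jax.jit(f_scaled)(x)

# Path 2: keep lax.map but disable jit, so it runs eagerly -- no crash.
jax.config.update("jax_disable_jit", True)
y2 = jax.lax.map(f_scaled, x)
jax.config.update("jax_disable_jit", False)
```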
System info (python version, jaxlib version, accelerator, etc.)
```
jax:    0.4.31
jaxlib: 0.4.31
numpy:  1.26.0
python: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
jax.devices (8 total, 8 local): [CudaDevice(id=0) CudaDevice(id=1) ... CudaDevice(id=6) CudaDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='XXX', release='5.15.0-122-generic', version='#132~20.04.1-Ubuntu SMP Fri Aug 30 15:50:07 UTC 2024', machine='x86_64')
```
```
$ nvidia-smi
Mon Sep 23 14:23:24 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03 Driver Version: 560.35.03 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 2080 Ti Off | 00000000:1A:00.0 Off | N/A |
| 30% 30C P0 53W / 250W | 160MiB / 11264MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA GeForce RTX 2080 Ti Off | 00000000:1B:00.0 Off | N/A |
| 30% 34C P0 53W / 250W | 160MiB / 11264MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 NVIDIA GeForce RTX 2080 Ti Off | 00000000:1D:00.0 Off | N/A |
| 30% 36C P0 53W / 250W | 160MiB / 11264MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 NVIDIA GeForce RTX 2080 Ti Off | 00000000:1E:00.0 Off | N/A |
| 30% 31C P0 62W / 250W | 160MiB / 11264MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 4 NVIDIA GeForce RTX 2080 Ti Off | 00000000:3D:00.0 Off | N/A |
| 30% 29C P0 66W / 250W | 160MiB / 11264MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 5 NVIDIA GeForce RTX 2080 Ti Off | 00000000:3E:00.0 Off | N/A |
| 30% 30C P0 53W / 250W | 160MiB / 11264MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 6 NVIDIA GeForce RTX 2080 Ti Off | 00000000:3F:00.0 Off | N/A |
| 30% 29C P0 51W / 250W | 160MiB / 11264MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 7 NVIDIA GeForce RTX 2080 Ti Off | 00000000:41:00.0 Off | N/A |
| 30% 32C P0 51W / 250W | 160MiB / 11264MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
| 1 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
| 2 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
| 3 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
| 4 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
| 5 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
| 6 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
| 7 N/A N/A 47282 C ...iniconda3/envs/scico/bin/python3.10 156MiB |
+-----------------------------------------------------------------------------------------+
```