NVIDIA Jetson Orin Nano: no kernel image is available for execution on the device
Description
Hi, I have compiled jax from source and also tried the prebuilt wheels, and I get the same error either way. This is in a Docker container on the device, an NVIDIA Jetson Orin Nano. Installed packages:
```
jax                0.4.31.dev20240728+6a7822a73  /home/jax
jax-cuda12-pjrt    0.4.30
jax-cuda12-plugin  0.4.30
jaxlib             0.4.31.dev20240729
```
```
F0729 12:58:37.236422 109 stream_executor_util.cc:504] Could not create RepeatBufferKernel: INTERNAL: Failed call to cudaGetFuncBySymbol: no kernel image is available for execution on the device
*** Check failure stack trace: ***
    @ 0xffff34829840  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @ 0xffff34829730  absl::lts_20230802::log_internal::LogMessage::Flush()
    @ 0xffff34829bc8  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @ 0xffff3204dd6c  xla::gpu::InitializeTypedBuffer<>()
    @ 0xffff32049cf0  xla::primitive_util::FloatingPointTypeSwitch<>()
    @ 0xffff32048484  xla::gpu::InitializeBuffer()
    @ 0xffff30d421b4  xla::gpu::AutotunerUtil::CreateBuffer()
    @ 0xffff30d410d8  xla::gpu::RedzoneBuffers::CreateInputs()
    @ 0xffff30d40d5c  xla::gpu::RedzoneBuffers::FromInstruction()
    @ 0xffff30b28558  xla::gpu::GemmFusionAutotunerImpl::Profile()
    @ 0xffff30b2a1b0  xla::gpu::GemmFusionAutotunerImpl::Autotune()
    @ 0xffff30b2c00c  xla::gpu::GemmFusionAutotuner::Run()
    @ 0xffff32583718  xla::HloPassPipeline::RunHelper()
    @ 0xffff325813fc  xla::HloPassPipeline::RunPassesInternal<>()
    @ 0xffff32580eb4  xla::HloPassPipeline::Run()
    @ 0xffff30b82320  xla::gpu::GpuCompiler::OptimizeHloPostLayoutAssignment()
    @ 0xffff30985228  xla::gpu::NVPTXCompiler::OptimizeHloPostLayoutAssignment()
    @ 0xffff30b7dbc0  xla::gpu::GpuCompiler::OptimizeHloModule()
    @ 0xffff30b855c0  xla::gpu::GpuCompiler::RunHloPasses()
    @ 0xffff30969830  xla::Service::BuildExecutable()
    @ 0xffff309052cc  xla::LocalService::CompileExecutables()
    @ 0xffff308f912c  xla::LocalClient::Compile()
    @ 0xffff308b37b0  xla::PjRtStreamExecutorClient::Compile()
    @ 0xffff3085de5c  xla::StreamExecutorGpuClient::Compile()
    @ 0xffff308b41f0  xla::PjRtStreamExecutorClient::Compile()
    @ 0xffff3082026c  std::__detail::__variant::__gen_vtable_impl<>::__visit_invoke()
    @ 0xffff308118c0  pjrt::PJRT_Client_Compile()
    @ 0xffff6d63b850  xla::InitializeArgsAndCompile()
    @ 0xffff6d63bd58  xla::PjRtCApiClient::Compile()
    @ 0xffff6e4b5b80  xla::ifrt::PjRtLoadedExecutable::Create()
    @ 0xffff6e4a99e4  xla::ifrt::PjRtCompiler::Compile()
    @ 0xffff6e18d770  xla::PyClient::CompileIfrtProgram()
    @ 0xffff6e18e0a0  xla::PyClient::Compile()
    @ 0xffff6e18eb50  nanobind::detail::func_create<>()::{lambda()#1}::operator()()
    @ 0xffff6e51ccdc  nanobind::detail::nb_func_vectorcall_complex()
    @ 0xffff76c557d8  nanobind::detail::nb_bound_method_vectorcall()
    @ 0xaaaae2c6d80c  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xaaaae2c6e2f8  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xaaaae2c6c7b0  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xaaaae2c6c7b0  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xaaaae2c70ca0  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xffff6e518a48  PyObject_Vectorcall
    @ 0xffff6e5193b8  nanobind::detail::obj_vectorcall()
    @ 0xffff6d5f261c  nanobind::detail::api<>::operator()<>()
    @ 0xffff6d5f5ccc  jax::WeakrefLRUCache::Call()
    @ 0xffff6d5f3c14  nanobind::detail::func_create<>()::{lambda()#1}::_FUN()
    @ 0xffff6e51ccdc  nanobind::detail::nb_func_vectorcall_complex()
    @ 0xaaaae2c793b0  _PyObject_FastCallDictTstate
    @ 0xaaaae2c8fe58  _PyObject_Call_Prepend
    @ 0xaaaae2dc1f10  (unknown)
    @ 0xaaaae2c7a180  _PyObject_MakeTpCall
    @ 0xaaaae2c70ff4  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xaaaae2c93b64  PyObject_Call
    @ 0xaaaae2c6e2f8  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xaaaae2c6c8ec  _PyEval_EvalFrameDefault
    @ 0xaaaae2c84828  _PyFunction_Vectorcall
    @ 0xaaaae2c93b64  PyObject_Call
Fatal Python error: Aborted
```
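The abort happens inside XLA's GEMM fusion autotuner during compilation, so any jitted matmul should trigger it. A minimal repro sketch (my reconstruction from the trace, not the exact script I ran):

```python
import jax
import jax.numpy as jnp

# Compiling any dot product for the GPU runs GemmFusionAutotuner,
# the XLA pass in which the stack trace above aborts.
x = jnp.ones((512, 512), dtype=jnp.float32)
y = jax.jit(lambda a: a @ a)(x)
print(y.sum())
```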
System info (python version, jaxlib version, accelerator, etc.)
Device: Jetson Orin Nano (kernel 5.15.136-tegra)

```
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 22.04.4 LTS
Release:        22.04
Codename:       jammy
```
```
Package: nvidia-jetpack
Source: nvidia-jetpack (6.0)
Version: 6.0+b106
Architecture: arm64
Maintainer: NVIDIA Corporation
Installed-Size: 194
Depends: nvidia-jetpack-runtime (= 6.0+b106), nvidia-jetpack-dev (= 6.0+b106)
Homepage: http://developer.nvidia.com/jetson
Priority: standard
Section: metapackages
Filename: pool/main/n/nvidia-jetpack/nvidia-jetpack_6.0+b106_arm64.deb
Size: 29296
SHA256: 561d38f76683ff865e57b2af41e303be7e590926251890550d2652bdc51401f8
SHA1: ef3fca0c1b5c780b2bad1bafae6437753bd0a93f
MD5sum: 95de21b4fce939dee11c6df1f2db0fa5
Description: NVIDIA Jetpack Meta Package
Description-md5: ad1462289bdbc54909ae109d1d32c0a8

Package: nvidia-jetpack
Source: nvidia-jetpack (6.0)
Version: 6.0+b87
Architecture: arm64
Maintainer: NVIDIA Corporation
Installed-Size: 194
Depends: nvidia-jetpack-runtime (= 6.0+b87), nvidia-jetpack-dev (= 6.0+b87)
Homepage: http://developer.nvidia.com/jetson
Priority: standard
Section: metapackages
Filename: pool/main/n/nvidia-jetpack/nvidia-jetpack_6.0+b87_arm64.deb
Size: 29298
SHA256: 70be95162aad864ee0b0cd24ac8e4fa4f131aa97b32ffa2de551f1f8f56bc14e
SHA1: 36926a991855b9feeb12072694005c3e7e7b3836
MD5sum: 050cb1fd604a16200d26841f8a59a038
Description: NVIDIA Jetpack Meta Package
Description-md5: ad1462289bdbc54909ae109d1d32c0a8
```
```
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:08:11_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0

NVRM version: NVIDIA UNIX Open Kernel Module for aarch64 540.3.0 Release Build (buildbrain@mobile-u64-6367-d8000) Mon May 6 10:21:04 PDT 2024
```
Using cuDNN version 8.9.4.25_1.
Same problem.
I actually found the solution. I was able to compile it using the correct compute capabilities:
```bash
python3 build/build.py --enable_cuda --cuda_path /usr/local/cuda-12.2 --cudnn_path /usr/lib/aarch64-linux-gnu --cuda_version 12.2 --cudnn_version 9 --cuda_compute_capabilities sm_87
```
But in general I had a lot of trouble getting the build to work. I think I used v0.4.28 together with the solution from #22155.
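To check that the rebuilt wheel is the one actually being imported (a quick sanity check, assuming it was installed into the active environment):

```python
import jax
import jaxlib

# A local build should report a dev version string rather than a
# prebuilt release such as 0.4.30.
print(jax.__version__, jaxlib.__version__)
print(jax.devices())  # expect a single CUDA device on the Orin Nano
```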
Okay, thank you for your suggestion, I will try it. Can you tell me how you found the right value for --cuda_compute_capabilities?
It is listed here: https://developer.nvidia.com/cuda-gpus
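You can also query the device directly. A minimal sketch using ctypes against the CUDA runtime (assuming JetPack's libcudart.so is on the loader path; attribute IDs 75 and 76 are cudaDevAttrComputeCapabilityMajor/Minor from cuda_runtime_api.h):

```python
import ctypes

# Ask the CUDA runtime for device 0's compute capability.
cudart = ctypes.CDLL("libcudart.so")
major, minor = ctypes.c_int(), ctypes.c_int()
assert cudart.cudaDeviceGetAttribute(ctypes.byref(major), 75, 0) == 0
assert cudart.cudaDeviceGetAttribute(ctypes.byref(minor), 76, 0) == 0
print(f"sm_{major.value}{minor.value}")  # prints sm_87 on the Orin Nano
```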
For new versions of jax:

```bash
python3 build/build.py --python_version=$PYTHON_VERSION --enable_cuda --cuda_compute_capabilities sm_87 \
  --bazel_options=--repo_env=LOCAL_CUDA_PATH="/usr/local/cuda-12.2" \
  --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/usr/lib/aarch64-linux-gnu"
```
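Once the wheel is installed, jax's built-in environment report is a convenient way to confirm what actually ended up in the environment:

```python
import jax

# Prints the jax/jaxlib versions, Python version, and detected devices.
jax.print_environment_info()
```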
I am in a similar situation. For two days I have been trying to get a working jax with GPU support up and running on a Jetson Orin Nano, and I'm lost. I tried building jax from source using various approaches, also specifying --cuda_version and --cudnn_version to match my setup, but it always fails with the "no kernel image is available..." error when I run training. Has anyone gotten this working?
Someone got JAX working on a Jetson and documented it here: https://github.com/dusty-nv/jetson-containers/tree/master/packages/ml/jax
Thanks, I'll have a look!