Unable to build jaxlib with debug symbols for GPU

Open juuso-oskari opened this issue 5 months ago • 2 comments

Description

I am trying to build jaxlib with debug symbols for XLA using the following command:

python build/build.py --enable_cuda --bazel_options=--override_repository=xla=/xla --bazel_options=--jobs=1 --bazel_options=--compilation_mode=dbg

The build runs fine until the very end, where it fails while linking xla_extension.so:

[1 / 2] Linking external/xla/xla/python/xla_extension.so; 102s local
ERROR: /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/external/xla/xla/python/BUILD:1257:21: Linking external/xla/xla/python/xla_extension.so failed: (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target @xla//xla/python:xla_extension.so) 
  (cd /root/.cache/bazel/_bazel_root/bee4ad1fd43279be7a03b33426e824d5/execroot/__main__ && \
  exec env - \
    LD_LIBRARY_PATH=/opt/amazon/efa/lib:/usr/local/cuda/compat/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64 \
    PATH=/opt/amazon/efa/bin:/usr/local/mpi/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/ucx/bin \
    PWD=/proc/self/cwd \
    TF_CUDA_COMPUTE_CAPABILITIES=sm_50,sm_60,sm_70,sm_80,compute_90 \
  external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc @bazel-out/k8-dbg/bin/external/xla/xla/python/xla_extension.so-2.params)
# Configuration: 5ca7bfb6889cfb0ee4db260a87240da2629ea60addd0c8f57834d31891b46935
# Execution platform: @local_execution_config_platform//:platform
collect2: fatal error: ld terminated with signal 9 [Killed]
compilation terminated.
[2 / 2] checking cached actions
Target //jaxlib/tools:build_wheel failed to build
INFO: Elapsed time: 162.700s, Critical Path: 160.90s
INFO: 2 processes: 2 internal.
FAILED: Build did NOT complete successfully

I don't get this error if I omit --enable_cuda, so it probably has something to do with CUDA. On the other hand, --enable_cuda works when I don't build with debug symbols (that is, without --bazel_options=--compilation_mode=dbg).
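The log line "ld terminated with signal 9 [Killed]" means the linker received SIGKILL. On Linux this is often, though not always, the kernel's out-of-memory killer, and a debug link of a CUDA-enabled xla_extension.so has very large inputs. The following is a minimal sketch for checking that hypothesis, not a confirmed diagnosis. The helper names `linker_was_killed` and `show_oom_events` and the file name `build.log` are hypothetical, and the sketch assumes a Linux host where `dmesg` is readable (it may require root).

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of jax or bazel): succeeds if the given
# build log contains the collect2 message for a SIGKILLed linker.
linker_was_killed() {
  grep -q 'ld terminated with signal 9' "$1"
}

# Hypothetical helper: print recent kernel OOM-killer messages, if any.
# Assumes dmesg is readable by the current user; prints nothing otherwise.
show_oom_events() {
  dmesg 2>/dev/null | grep -i -E 'out of memory|oom-kill|killed process' | tail -n 20
}
```

One way to use it would be to save the build output with `python build/build.py ... 2>&1 | tee build.log`, then run `linker_was_killed build.log && show_oom_events`, and compare against `free -h` to see how much memory and swap the machine has.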

System info (python version, jaxlib version, accelerator, etc.)

python version: Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] on linux
jax: latest commit https://github.com/google/jax.git
xla: latest commit https://github.com/openxla/xla
accelerator (info from nvidia-smi): NVIDIA-SMI 560.28.03, Driver Version: 560.28.03, CUDA Version: 12.6, NVIDIA GeForce RTX 4080 Laptop GPU

To reproduce the error in a Docker container:

# download local jax and xla repos
git clone https://github.com/openxla/xla.git
git clone https://github.com/google/jax
# start up nvidia docker container for jax
sudo docker run --name jax --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -it -d -v $PWD/jax:/jax -v $PWD/xla:/xla -v ~/.cache/bazel:/root/.cache/bazel nvcr.io/nvidia/jax:24.04-py3
# enter container in interactive mode
docker exec -it jax bash
# build jaxlib (produces the error)
cd /jax
python build/build.py --enable_cuda --bazel_options=--override_repository=xla=/xla --bazel_options=--jobs=4 --bazel_options=--compilation_mode=dbg

juuso-oskari, Sep 11 '24 10:09