
Issue when trying to build extension from source with CUDA 12.8

Open cjfreeze opened this issue 9 months ago • 6 comments

I'm trying to build the extension from source so that I can use Bumblebee with the new RTX 50xx-series GPUs.

Reproduce

I tried to build it by running XLA_BUILD=true XLA_TARGET=cuda mix compile after cleaning the xla dependency.

In a project with the following dependencies:

[
  {:bumblebee, "~> 0.6.0"},
  {:nx, "~> 0.9.0"},
  {:exla, "~> 0.9.0"}
]

Encountered the following error:

ERROR: /home/chris/.cache/xla_extension/xla-fd58925adee147d38c25a085354e15427a12d00a/xla/service/gpu/BUILD:1119:23: Compiling xla/service/gpu/cub_sort_kernel.cu.cc failed: (Exit 2): crosstool_wrapper_driver_is_not_gcc failed: error executing command (from target //xla/service/gpu:cub_sort_kernel_f16) external/local_config_cuda/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -MD -MF bazel-out/k8-opt/bin/xla/service/gpu/_objs/cub_sort_kernel_f16/cub_sort_kernel.cu.pic.d ... (remaining 85 arguments skipped)
nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
nvcc warning : Support for offline compilation for architectures prior to '<compute/sm/lto>_75' will be removed in a future release (Use -Wno-deprecated-gpu-targets to suppress warning).
./xla/service/gpu/gpu_prim.h(48): error: no instance of overloaded function "cub::ThreadLoadVolatilePointer" matches the specified type
  __attribute__((device)) __inline__ __attribute__((always_inline)) Eigen::half ThreadLoadVolatilePointer<Eigen::half>(
                                                                                ^

./xla/service/gpu/gpu_prim.h(63): error: no instance of overloaded function "cub::ThreadLoadVolatilePointer" matches the specified type
  ThreadLoadVolatilePointer<tsl::bfloat16>(tsl::bfloat16 *ptr,
  ^

2 errors detected in the compilation of "xla/service/gpu/cub_sort_kernel.cu.cc".
Target //xla/extension:xla_extension failed to build
Use --verbose_failures to see the command lines of failed build steps.
INFO: Elapsed time: 149.076s, Critical Path: 47.28s
INFO: 6287 processes: 4224 internal, 2063 local.
FAILED: Build did NOT complete successfully
make: *** [Makefile:26: /home/chris/.cache/xla/0.8.0/build/xla_extension-0.8.0-x86_64-linux-gnu-cuda.tar.gz] Error 1

Environment

Ubuntu 22.04
NVIDIA driver version: 570.124.06
CUDA version: 12.8

Notes

It is worth noting that I was trying to do this in order to work around another error I was getting:

(RuntimeError) ptxas exited with non-zero error code 65280, output: ptxas fatal   : Program with .target 'sm_90a' cannot be compiled to future architecture

I believe this is caused by the RTX 50xx GPU in my system.

cjfreeze • Mar 25 '25 18:03

I downgraded to CUDA 12.4 and got the same build error.

cjfreeze • Mar 25 '25 19:03

What cuDNN version are you using? I'm a bit surprised this doesn't work out of the box with the 50xx series.

polvalente • Mar 25 '25 19:03

@cjfreeze the best bet may be to see if the build works on the latest XLA, you can try building with this env var OPENXLA_GIT_REV=b8bffdaa852ddb5a435d800331e916cc0e52857b. Note that in order to actually upgrade, we likely need to adjust EXLA code, but knowing if that helps would be a good indicator to do that :)
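A minimal Python sketch of that suggestion, assuming mix is on PATH and the script is run from the Mix project root. The helper names build_env and compile_with_rev are hypothetical; the environment variables and the revision are the ones given in this thread. It does not clean the previously built xla dependency, so do that first, as in the original report.

```python
import os
import subprocess

# Revision suggested in this thread; it is not guaranteed to build.
OPENXLA_REV = "b8bffdaa852ddb5a435d800331e916cc0e52857b"


def build_env(rev, base=None):
    """Return a copy of the environment with the from-source CUDA build variables set."""
    env = dict(os.environ if base is None else base)
    env["XLA_BUILD"] = "true"        # build xla_extension from source instead of downloading
    env["XLA_TARGET"] = "cuda"       # CUDA target, as in the original report
    env["OPENXLA_GIT_REV"] = rev     # pin the OpenXLA revision to build against
    return env


def compile_with_rev(rev=OPENXLA_REV, project_dir="."):
    """Hypothetical helper: run `mix compile` in project_dir with the pinned revision."""
    subprocess.run(["mix", "compile"], cwd=project_dir, env=build_env(rev), check=True)
```

Calling compile_with_rev() raises CalledProcessError if the Bazel build fails, so the failing revision surfaces immediately.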

jonatanklosko • Mar 26 '25 03:03

> What cuDNN version are you using? I'm a bit surprised this doesn't work out of the box with the 50xx series.

cuDNN 9.8.0, I believe.

> @cjfreeze the best bet may be to see if the build works on the latest XLA, you can try building with this env var OPENXLA_GIT_REV=b8bffdaa852ddb5a435d800331e916cc0e52857b. Note that in order to actually upgrade, we likely need to adjust EXLA code, but knowing if that helps would be a good indicator to do that :)

I should have mentioned that I did try rebuilding against a variety of OpenXLA commit hashes, and they all failed; almost every one failed with an error about missing files in the @TSL third-party dependency. I didn't mention it originally because I wasn't sure whether you were targeting that specific hash for a particular reason, or whether it was just the latest one when you wrote the build file.

The whole reason I was building from source in the first place was due to the hardware target SM_120 not being present in the autogenerated files created by this extension and used by the Elixir dependency EXLA. I believe that is why I can't run Bumblebee models on my new RTX 5090, and I was hoping that building from source would fix it.
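To confirm which target a given card needs, a sketch like the following maps the compute capability reported by nvidia-smi to an sm_XY tag. It assumes a driver recent enough to support the compute_cap query field; an RTX 5090 (consumer Blackwell) reports 12.0, which corresponds to sm_120. The function names are hypothetical.

```python
import subprocess


def parse_compute_cap(text):
    """Turn nvidia-smi compute_cap output (one "major.minor" per GPU) into sm_XY tags."""
    tags = []
    for line in text.strip().splitlines():
        major, minor = line.strip().split(".")
        tags.append(f"sm_{int(major)}{int(minor)}")
    return tags


def detect_gpu_archs():
    """Query every visible GPU and return its sm_XY tag."""
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return parse_compute_cap(out)
```

If the returned tag is missing from the list of compute capabilities the extension was built with, that is consistent with the failure described above.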

cjfreeze • Mar 26 '25 19:03

> due to the hardware target SM_120 not being present in the autogenerated files

@cjfreeze we build with the following flag, which may be related:

https://github.com/elixir-nx/xla/blob/e7f24308aa27c75ed5ad44e2b11e9d134a0f4016/lib/xla.ex#L314-L315

We use whatever JAX uses, though.

Another thing you can try is running any computation with the latest JAX, for example:

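pip install -U "jax[cuda12]"

import jax
import jax.numpy as jnp

# A trivial reduction; jax.jit compiles it through XLA for the default device.
def f(x):
  return jnp.sum(x)

x = jnp.array([1, 2, 3])

# Should show a CUDA device if JAX picked up the GPU backend.
print(x.device)

# Compiles and runs f on the default device.
print(jax.jit(f)(x))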

If you run into issues, that's most likely a limitation of the current XLA, and opening an issue in the JAX repository gives the best chance of it being addressed.

jonatanklosko • Mar 27 '25 02:03

>> due to the hardware target SM_120 not being present in the autogenerated files
>
> @cjfreeze we build with the following flag, which may be related:
>
> xla/lib/xla.ex, lines 314 to 315 in e7f2430:
>
> See https://github.com/google/jax/blob/66a92c41f6bac74960159645158e8d932ca56613/.bazelrc#L68
>
> ~s/--config=cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"/

Ah, yes, when I build other libraries for my 5xxx series I have to include sm_120 in that list. XLA definitely supports it, as long as you build for it.
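A minimal sketch of that change, operating on the flag string quoted from xla/lib/xla.ex above. The helper add_capability is hypothetical; it only rewrites the string, and how the modified flags reach the build (for example, by editing a local checkout of lib/xla.ex) is left as an assumption. Whether the pinned XLA revision then builds cleanly for sm_120 is not established in this thread.

```python
import re

# Flag string as quoted from xla/lib/xla.ex (commit e7f2430) in this thread.
BAZEL_FLAGS = '--config=cuda --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90"'

# Captures the comma-separated list inside the quotes.
_CAPS = re.compile(r'TF_CUDA_COMPUTE_CAPABILITIES="([^"]*)"')


def add_capability(flags, cap):
    """Return flags with cap appended to TF_CUDA_COMPUTE_CAPABILITIES, skipping duplicates."""
    match = _CAPS.search(flags)
    if match is None:
        raise ValueError("TF_CUDA_COMPUTE_CAPABILITIES not found in flags")
    caps = [c for c in match.group(1).split(",") if c]
    if cap not in caps:
        caps.append(cap)
    replacement = 'TF_CUDA_COMPUTE_CAPABILITIES="' + ",".join(caps) + '"'
    return flags[: match.start()] + replacement + flags[match.end():]
```

For example, add_capability(BAZEL_FLAGS, "sm_120") keeps the existing entries and appends sm_120 after compute_90; calling it again with the same tag leaves the string unchanged.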

cjfreeze • Mar 27 '25 19:03

Do I understand correctly that Blackwell cards aren't supported at the moment? I ran into the same problem trying to run Nx on an RTX 5080.

petrkozorezov • May 23 '25 12:05

We have work in progress on updating XLA, which should help with this. It is mostly working, but we still have to figure out how to keep supporting different CUDA versions.

polvalente • May 27 '25 10:05

I've just released xla v0.9.0 with an updated XLA revision, precompiled with support for the RTX 50xx series. Note that it also requires changes in EXLA; I opened a PR: https://github.com/elixir-nx/nx/pull/1614.

jonatanklosko • Jun 16 '25 13:06