[Compile] accelerate compilation speed using NVRTC
This PR adds NVRTC as an alternative to NVCC for faster, in-process JIT compilation of CUDA kernels, as a follow-up to https://github.com/apache/tvm-ffi/pull/283.
It enhances the CUDA compilation backend by:
- Adding Python NVRTC support using cuda-python bindings
- Removing legacy C++ NVRTC fallback in favor of a Python-first approach
- Keeping nvcc as the default compiler with fatbin output (no behavior change for existing users)
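As a rough sketch of what the Python-first NVRTC path looks like (assuming the `cuda-python` package; function and kernel names here are illustrative, not this PR's actual code):

```python
# Illustrative sketch, not the PR's implementation: compile CUDA C to PTX
# in-process with NVRTC via the cuda-python bindings, avoiding an nvcc
# subprocess. Degrades to a no-op when cuda-python is not installed.
try:
    from cuda import nvrtc  # pip install cuda-python
    HAVE_NVRTC = True
except ImportError:
    HAVE_NVRTC = False


def nvrtc_compile(source, arch=b"compute_80"):
    """Return PTX bytes for `source`, or None if NVRTC is unavailable."""
    if not HAVE_NVRTC:
        return None
    err, prog = nvrtc.nvrtcCreateProgram(source.encode(), b"kernel.cu", 0, [], [])
    assert err == nvrtc.nvrtcResult.NVRTC_SUCCESS
    opts = [b"--gpu-architecture=" + arch]
    err, = nvrtc.nvrtcCompileProgram(prog, len(opts), opts)
    if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
        # Surface the compiler log on failure.
        _, log_size = nvrtc.nvrtcGetProgramLogSize(prog)
        log = b" " * log_size
        nvrtc.nvrtcGetProgramLog(prog, log)
        raise RuntimeError(log.decode())
    err, ptx_size = nvrtc.nvrtcGetPTXSize(prog)
    ptx = b" " * ptx_size
    err, = nvrtc.nvrtcGetPTX(prog, ptx)
    return bytes(ptx)


ptx = nvrtc_compile('extern "C" __global__ void noop() {}')
```

Because NVRTC runs inside the Python process, it skips the per-kernel nvcc process launch and temporary-file I/O, which is where most of the speedup below comes from.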
Users can select the compilation backend via the environment variable `TVM_CUDA_COMPILE_MODE`, which accepts `"nvcc"` or `"nvrtc"`. For example:

```shell
TVM_CUDA_COMPILE_MODE=nvrtc python3 your_program.py
```
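A minimal sketch of how such an environment-variable switch can be read on the Python side (names are illustrative; the PR's actual internals may differ):

```python
import os


def cuda_compile_mode():
    # Hypothetical helper: "nvcc" remains the default, so users who never
    # set TVM_CUDA_COMPILE_MODE see no behavior change.
    mode = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
    if mode not in ("nvcc", "nvrtc"):
        raise ValueError(f"invalid TVM_CUDA_COMPILE_MODE: {mode!r}")
    return mode
```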
Here is a short benchmark of the compilation speed of the kernels in `test_target_codegen_cuda.py`.
NVCC vs NVRTC Compilation Time Comparison (Python-side Call)
| Test Case | Code Size | NVCC Time (ms) | NVRTC Time (ms) | Speedup |
|---|---|---|---|---|
| test_crossthread_reduction1 | 1945 B | 241.27 | 51.23 | 4.7x |
| test_cuda_bf16_vectorize_add | 3760 B | 342.72 | 44.50 | 7.7x |
| test_cuda_const_float_to_half | 12394 B | 272.85 | 31.99 | 8.5x |
| test_cuda_device_func_call | 975 B | 215.58 | 21.47 | 10.0x |
| test_cuda_float_const_hex_format | 685 B | 217.39 | 20.52 | 10.6x |
| test_cuda_floordiv_with_vectorization | 1050 B | 213.88 | 23.32 | 9.2x |
| test_cuda_inf_nan | 673 B | 214.33 | 24.94 | 8.6x |
| test_cuda_tensormap | 755 B | 213.91 | 20.74 | 10.3x |
| test_cuda_thread_sync_inside_condition | 1007 B | 213.43 | 28.29 | 7.5x |
| test_cuda_vectorize_add | 908 B | 226.81 | 40.39 | 5.6x |
| test_cuda_vectorize_load | 734 B | 217.25 | 24.02 | 9.0x |
| test_device_host_call_same_func | 924 B | 216.03 | 21.21 | 10.2x |
| test_vectorized_intrin1 | 847 B | 226.15 | 26.34 | 8.6x |
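The Speedup column is simply the NVCC time divided by the NVRTC time; a quick check against a few rows:

```python
# Recompute the Speedup column from the measured times in the table above.
rows = [
    ("test_crossthread_reduction1", 241.27, 51.23),
    ("test_cuda_bf16_vectorize_add", 342.72, 44.50),
    ("test_cuda_device_func_call", 215.58, 21.47),
]
for name, nvcc_ms, nvrtc_ms in rows:
    print(f"{name}: {nvcc_ms / nvrtc_ms:.1f}x")
# → 4.7x, 7.7x, 10.0x, matching the table
```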