CUDA.jl
Easier way to do mixed-mode matrix multiplication
Describe the bug

In deep learning, people often use fp16 matmuls with fp32 accumulation (the cuBLAS compute type) as a balance between performance and preserving numerical accuracy. In Torch, if you do an fp16-by-fp16 matmul, the fp32 compute type is the default behavior. In CUDA.jl the default is fp16 accumulation, and there doesn't seem to be an easy way to get the fp32-accumulation behavior.
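As a rough illustration of why the accumulation type matters (a minimal sketch; sizes are arbitrary and this assumes a working GPU):

```julia
using CUDA, LinearAlgebra

# Illustrative only: compare the default GPU fp16 matmul (currently fp16
# accumulation via cublasHgemm) against an fp32-accumulated reference.
A = rand(Float16, 1024, 1024)
B = rand(Float16, 1024, 1024)

ref = Float32.(A) * Float32.(B)        # fp32 accumulation on the CPU
gpu = Array(CuArray(A) * CuArray(B))   # current CUDA.jl default path

println("relative error vs fp32 accumulation: ",
        norm(Float32.(gpu) - ref) / norm(ref))
```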
It would be great if there were either a toggle to change this behavior, similar to math_mode, or maybe even to make the fp32-accumulation behavior the default.
Specifically, fp16 gemm! currently dispatches to cublasHgemm, whereas the suggested behavior (and the way Torch does it) is to dispatch to cublasGemmEx (or cublasSgemmEx) with the input/output datatype arguments set to fp16 and the compute type set to fp32.
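For reference, the closest workaround I'm aware of is something like the sketch below, using an fp32 destination so that cuBLAS can pick an fp32 compute type; the gemmEx! wrapper name and signature here are assumptions on my part, and this still isn't the requested fp16-in/fp16-out behavior since it materializes an fp32 result:

```julia
using CUDA
using CUDA.CUBLAS

# Sketch only: fp16 inputs with an fp32 destination. As far as I can tell this
# gets serviced through GemmEx with an fp32 compute type, but that's an
# assumption, not something I've verified against the current wrappers.
A = CuArray(rand(Float16, 256, 256))
B = CuArray(rand(Float16, 256, 256))
C = CUDA.zeros(Float32, 256, 256)
CUBLAS.gemmEx!('N', 'N', true, A, B, false, C)
```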
This also applies to batched matmuls, where CUDA.jl dispatches to cublasHgemmBatched, and possibly to batched matrix-vector products as well.
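A batched sketch (NNlib and its CUDA extension for batched_mul are assumptions here, used only for convenience):

```julia
using CUDA, NNlib

# batched_mul on 3-D CuArrays is expected to go through the batched fp16 GEMM
# path today, i.e. fp16 accumulation.
A = CuArray(rand(Float16, 64, 64, 8))
B = CuArray(rand(Float16, 64, 64, 8))
C = batched_mul(A, B)
```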
I'm happy to open a PR if the maintainers decide it's okay to change the current behavior without introducing a setting. If a setting is needed, it might be better for someone more familiar with the project's structure to do this.
To reproduce

Use Nsight Compute to see that the kernel used is ampere_h1688gemm_128x128_ldg8_stages_32x1_nn, or something else containing h1688.
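For example, profile something like this minimal snippet (sizes arbitrary):

```julia
using CUDA

# Minimal fp16 matmul to run under Nsight Compute; the launched GEMM kernel
# name should contain "h1688" on this hardware.
A = CuArray(rand(Float16, 4096, 4096))
B = CuArray(rand(Float16, 4096, 4096))
CUDA.@sync A * B
```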
Version info
Details
julia> versioninfo()
Julia Version 1.9.2
Commit e4ee485e909 (2023-07-05 09:39 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 32 × 13th Gen Intel(R) Core(TM) i9-13900K
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, goldmont)
  Threads: 1 on 32 virtual cores
Environment:
  JULIA_EDITOR = code
  JULIA_NUM_THREADS =
julia> CUDA.versioninfo()
CUDA runtime 12.1, artifact installation
CUDA driver 12.2
NVIDIA driver 535.86.5
CUDA libraries:
- CUBLAS: 12.1.3
- CURAND: 10.3.2
- CUFFT: 11.0.2
- CUSOLVER: 11.4.5
- CUSPARSE: 12.1.0
- CUPTI: 18.0.0
- NVML: 12.0.0+535.86.5
Julia packages:
- CUDA: 4.4.0
- CUDA_Driver_jll: 0.5.0+1
- CUDA_Runtime_jll: 0.6.0+0
Toolchain:
- Julia: 1.9.2
- LLVM: 14.0.6
- PTX ISA support: 3.2, 4.0, 4.1, 4.2, 4.3, 5.0, 6.0, 6.1, 6.3, 6.4, 6.5, 7.0, 7.1, 7.2, 7.3, 7.4, 7.5
- Device capability support: sm_37, sm_50, sm_52, sm_53, sm_60, sm_61, sm_62, sm_70, sm_72, sm_75, sm_80, sm_86
1 device:
- 0: NVIDIA GeForce RTX 4090 (sm_89, 21.899 GiB / 23.988 GiB available)