
WMMA TensorFloat32 (TF32)

Open · carstenbauer opened this issue 2 years ago • 24 comments

In this PR I will try to add TensorFloat32 (TF32) support for WMMAs. I'm surely the wrong person for the job, but let's see where things will take me / us 😄

Some resources:

  • TF32 format: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/ and https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#wmma-altfp
  • Element types and matrix sizes: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#wmma-type-sizes
  • PTX: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions
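TF32 keeps Float32's sign bit and 8-bit exponent but only 10 explicit mantissa bits, so the 13 low mantissa bits of a Float32 input do not take part in the product. A CPU reference for checking TF32 WMMA output can therefore be built by clearing those bits first. This is a minimal sketch that assumes the tensor cores simply ignore the low bits (truncation). An explicit conversion may round instead, so treat it as an approximation. `tf32_truncate` is a hypothetical helper, not part of CUDA.jl:

```julia
# Hypothetical helper (not CUDA.jl API): emulate TF32 input precision on the CPU.
# TF32 = 1 sign bit + 8 exponent bits + 10 mantissa bits, i.e. Float32 with the
# 13 low mantissa bits cleared. Assumption: truncation, not round-to-nearest.
tf32_truncate(x::Float32) = reinterpret(Float32, reinterpret(UInt32, x) & 0xffffe000)

# Element-wise version, e.g. to prepare reference inputs for A * B + C.
tf32_truncate(A::AbstractArray{Float32}) = tf32_truncate.(A)
```

For example, `1.0f0 + 2f0^-10` survives unchanged because its lowest set bit is the 10th mantissa bit, while `1.0f0 + 2f0^-11` collapses to `1.0f0`.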

cc: @HenriDeh @thomasfaingnaert

carstenbauer · Mar 01 '22 20:03

The initial naive attempt didn't work. For the following test kernel,

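using CUDA

# A is m x k = 16 x 8 and B is k x n = 8 x 16, both column-major, so the load
# strides (leading dimensions) are 16 and 8; C and D are 16 x 16 with stride 16.
function kernel_wmma_tf32_lowlevel(a_dev, b_dev, c_dev, d_dev)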
    a_frag = WMMA.llvm_wmma_load_a_col_m16n16k8_global_stride_tf32(pointer(a_dev), 16)
    b_frag = WMMA.llvm_wmma_load_b_col_m16n16k8_global_stride_tf32(pointer(b_dev), 8)
    c_frag = WMMA.llvm_wmma_load_c_col_m16n16k8_global_stride_f32(pointer(c_dev), 16)

    d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k8_tf32_tf32(a_frag, b_frag, c_frag)

    WMMA.llvm_wmma_store_d_col_m16n16k8_global_stride_f32(pointer(d_dev), d_frag, 16)
    return nothing
end

function call_kernel()
    m = n = 16
    k = 8
    dtype_a = dtype_b = Float32
    dtype_c = dtype_d = Float32

    d_a = CUDA.rand(dtype_a, m, k)
    d_b = CUDA.rand(dtype_b, k, n)
    d_c = CUDA.rand(dtype_c, m, n)
    d_d = CUDA.zeros(dtype_d, m, n)

    CUDA.@sync @cuda threads=32 kernel_wmma_tf32_lowlevel(d_a, d_b, d_c, d_d)  # WMMA ops are warp-wide: launch one full warp
    return nothing
end

I get

julia> call_kernel()                                                                                                                                    
ERROR: InvalidIRError: compiling kernel kernel_wmma_tf32_lowlevel(CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p1i8)                                           
Stacktrace:                                                                                                                                             
 [1] llvm_wmma_load_a_col_m16n16k8_global_stride_tf32                                                                                                   
   @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/device/intrinsics/wmma.jl:214                                               
 [2] kernel_wmma_tf32_lowlevel                                                                                                                          
   @ ./REPL[2]:2                                                                                                                                        
Reason: unsupported call to an unknown function (call to llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p1i8)                                           
Stacktrace:                                                                                                                                             
 [1] llvm_wmma_load_b_col_m16n16k8_global_stride_tf32                                                                                                   
   @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/device/intrinsics/wmma.jl:214                                               
 [2] kernel_wmma_tf32_lowlevel                                                                                                                          
   @ ./REPL[2]:3                                                                                                                                        
Reason: unsupported call to an unknown function (call to llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32.p1i8)                                            
Stacktrace:                                                                                                                                             
 [1] llvm_wmma_load_c_col_m16n16k8_global_stride_f32                                                                                                    
   @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/device/intrinsics/wmma.jl:214                                               
 [2] kernel_wmma_tf32_lowlevel                                                                                                                          
   @ ./REPL[2]:4                                                                                                                                        
Reason: unsupported use of an undefined name (use of 'llvm_wmma_mma_col_col_m16n16k8_tf32_tf32')                                                        
Stacktrace:                                                                                                                                             
 [1] getproperty                                                                                                                                        
   @ ./Base.jl:35                                                                                                                                       
 [2] kernel_wmma_tf32_lowlevel                                                                                                                          
   @ ./REPL[2]:6
Reason: unsupported dynamic function invocation
Stacktrace:
 [1] kernel_wmma_tf32_lowlevel
   @ ./REPL[2]:6
Reason: unsupported dynamic function invocation (call to llvm_wmma_store_d_col_m16n16k8_global_stride_f32)
Stacktrace:
 [1] kernel_wmma_tf32_lowlevel
   @ ./REPL[2]:8
HINT: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code
Stacktrace:
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(kernel_wmma_tf32_lowlevel), NTuple{4, CuDeviceMatrix{Float32, 1}}}}, args::LLVM.Module)
    @ GPUCompiler /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/validation.jl:119
  [2] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/driver.jl:327 [inlined]
  [3] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/.julia/packages/TimerOutputs/5tW2E/src/TimerOutput.jl:252 [inlined]
  [4] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/driver.jl:325 [inlined]
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/utils.jl:64
  [6] cufunction_compile(job::GPUCompiler.CompilerJob)
    @ CUDA /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:326
  [7] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/cache.jl:90
  [8] cufunction(f::typeof(kernel_wmma_tf32_lowlevel), tt::Type{NTuple{4, CuDeviceMatrix{Float32, 1}}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:297
  [9] cufunction
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:291 [inlined]
 [10] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:102 [inlined]
 [11] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/utilities.jl:25 [inlined]
 [12] call_kernel()
    @ Main ./REPL[3]:12
 [13] top-level scope
    @ REPL[4]:1
 [14] top-level scope
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/initialization.jl:52

Any form of help would be highly appreciated :)

carstenbauer · Mar 01 '22 20:03

At first sight, there are two issues:

  1. This error means that the name of your wrapper is not correct:
Reason: unsupported use of an undefined name (use of 'llvm_wmma_mma_col_col_m16n16k8_tf32_tf32')

This name is generated at https://github.com/JuliaGPU/CUDA.jl/blob/c24234d3a9730727113086f917a3d2662a4da782/src/device/intrinsics/wmma.jl#L334, and with your definition at https://github.com/JuliaGPU/CUDA.jl/blob/c24234d3a9730727113086f917a3d2662a4da782/src/device/intrinsics/wmma.jl#L103, the correct name should be WMMA.llvm_wmma_mma_col_col_m16n16k8_f32_f32 instead of WMMA.llvm_wmma_mma_col_col_m16n16k8_tf32_tf32.

  2. This type of error indicates that the name of the LLVM IR intrinsic is wrong:
Reason: unsupported call to an unknown function (call to llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p1i8)

Have you verified that this is the correct name of the intrinsic? You can do so by compiling a CUDA C source file that uses TF32 WMMA using Clang, and using the command-line option -S -emit-llvm. If the name is correct, we're likely missing the latest NVPTX patches in Julia.
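Both kinds of names follow a mechanical pattern, which makes typos easy to spot. The sketch below is a hypothetical reconstruction based only on the names that appear in this thread: the real generator lives in src/device/intrinsics/wmma.jl and may differ. It illustrates two points. First, the trailing types of the mma wrapper are the D and C element types rather than the A/B input types, which is why `_tf32_tf32` was undefined. Second, the load intrinsic carries a pointer-type suffix:

```julia
# Hypothetical helpers (not CUDA.jl API) that rebuild the names seen in this thread.

# Julia wrapper for an mma: the trailing components are the D and C element types,
# not the A/B input types. Empty components are dropped.
function wmma_mma_wrapper_name(a_layout, b_layout, shape, d_type, c_type)
    parts = ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_type, c_type]
    return join(filter(!isempty, parts), "_")
end

# NVVM intrinsic for a strided load. ".p1i8" is LLVM's mangling of the overloaded
# pointer argument: an i8 pointer in address space 1 (global memory on NVPTX).
function wmma_load_intrinsic(shape, matrix, layout, elem_type; addrspace = 1)
    return "llvm.nvvm.wmma.$shape.load.$matrix.$layout.stride.$elem_type.p$(addrspace)i8"
end
```

For example, `wmma_mma_wrapper_name("col", "col", "m16n16k8", "f32", "f32")` yields `llvm_wmma_mma_col_col_m16n16k8_f32_f32`. Meanwhile, `wmma_load_intrinsic("m16n16k8", "a", "col", "tf32")` reproduces the name from the error message. When comparing that name with Clang's `-S -emit-llvm` output, ignore the pointer suffix.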

thomasfaingnaert · Mar 01 '22 21:03

> Have you verified that this is the correct name of the intrinsic? You can do so by compiling a CUDA C source file that uses TF32 WMMA using Clang, and using the command-line option -S -emit-llvm.

Actually, it might be easier to just look at the tests of Clang for the expected intrinsic names. This unit test seems to contain all the intrinsic names for WMMA: https://github.com/llvm/llvm-project/blob/main/clang/test/CodeGen/builtins-nvptx-mma.cu.
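If you have a local copy of that test file, a few lines of Julia can pull out the TF32 names to compare against. This is a sketch under two assumptions: the file has been downloaded separately, and the intrinsic names appear in it as plain `llvm.nvvm.wmma....` tokens. `tf32_wmma_intrinsics` is a hypothetical helper:

```julia
# Hypothetical helper: collect the distinct WMMA intrinsic names that mention tf32
# from a chunk of text (e.g. the contents of builtins-nvptx-mma.cu).
# Assumption: names are written as unbroken llvm.nvvm.wmma.* tokens.
function tf32_wmma_intrinsics(text::AbstractString)
    pattern = r"llvm\.nvvm\.wmma\.[A-Za-z0-9_.]*tf32[A-Za-z0-9_.]*"
    names = [m.match for m in eachmatch(pattern, text)]
    return sort(unique(names))
end
```

Usage, for example: `tf32_wmma_intrinsics(read("builtins-nvptx-mma.cu", String))`.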

thomasfaingnaert · Mar 02 '22 07:03

Thanks for your comments. I fixed (1):

function kernel_wmma_tf32_lowlevel(a_dev, b_dev, c_dev, d_dev)
    a_frag = WMMA.llvm_wmma_load_a_col_m16n16k8_global_stride_tf32(pointer(a_dev), 16)
    b_frag = WMMA.llvm_wmma_load_b_col_m16n16k8_global_stride_tf32(pointer(b_dev), 8)
    c_frag = WMMA.llvm_wmma_load_c_col_m16n16k8_global_stride_f32(pointer(c_dev), 16)

    d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k8_f32_f32(a_frag, b_frag, c_frag)

    WMMA.llvm_wmma_store_d_col_m16n16k8_global_stride_f32(pointer(d_dev), d_frag, 16)
    return nothing
end

function call_kernel()
    m = n = 16
    k = 8
    dtype_a = dtype_b = Float32
    dtype_c = dtype_d = Float32

    d_a = CUDA.rand(dtype_a, m, k)
    d_b = CUDA.rand(dtype_b, k, n)
    d_c = CUDA.rand(dtype_c, m, n)
    d_d = CUDA.zeros(dtype_d, m, n)

    CUDA.@sync @cuda threads=32 kernel_wmma_tf32_lowlevel(d_a, d_b, d_c, d_d)  # WMMA ops are warp-wide: launch one full warp
    return nothing
end

As for (2), here it says

llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32

so I think that the intrinsic is correct. (Am I right in assuming that the names in the comments are the ones to check?)

> If the name is correct, we're likely missing the latest NVPTX patches in Julia.

Can you elaborate? Do you mean that our LLVM is too old and we need specific NVPTX patches / backports? If so, I guess that would imply that there is nothing that can be done here (and that things can only work from Julia 1.9 on?)

carstenbauer · Mar 02 '22 16:03

I should probably mention that I used Julia 1.7 above. With Julia 1.8-beta1, which uses LLVM 13 (instead of 12), the loads seem to pass, and I get

julia> call_kernel()                                                                                                                                                               
ERROR: InvalidIRError: compiling kernel kernel_wmma_tf32_lowlevel(CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}) resulted in invalid LLVM IR
Reason: unsupported call to an unknown function (call to llvm.nvvm.wmma.m16n16k8.mma.col.col.f32.f32)                                                                              
Stacktrace:                                                                                                                                                                        
 [1] llvm_wmma_mma_col_col_m16n16k8_f32_f32                                                                                                                                        
   @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/device/intrinsics/wmma.jl:355                                                                          
 [2] kernel_wmma_tf32_lowlevel                                                                                                                                                     
   @ ./REPL[3]:6                                                                                                                                                                   
HINT: catch this exception as `err` and call `code_typed(err; interactive = true)` to introspect the erronous code                                                                 
Stacktrace:                                                                                                                                                                        
  [1] check_ir(job::GPUCompiler.CompilerJob{GPUCompiler.PTXCompilerTarget, CUDA.CUDACompilerParams, GPUCompiler.FunctionSpec{typeof(kernel_wmma_tf32_lowlevel), NTuple{4, CuDeviceMatrix{Float32, 1}}}}, args::LLVM.Module)
    @ GPUCompiler /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/validation.jl:119
  [2] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/driver.jl:327 [inlined]      
  [3] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/.julia/packages/TimerOutputs/5tW2E/src/TimerOutput.jl:252 [inlined]
  [4] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/driver.jl:325 [inlined]      
  [5] emit_asm(job::GPUCompiler.CompilerJob, ir::LLVM.Module; strip::Bool, validate::Bool, format::LLVM.API.LLVMCodeGenFileType)
    @ GPUCompiler /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/utils.jl:64
  [6] cufunction_compile(job::GPUCompiler.CompilerJob)                                                                                                                             
    @ CUDA /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:326
  [7] cached_compilation(cache::Dict{UInt64, Any}, job::GPUCompiler.CompilerJob, compiler::typeof(CUDA.cufunction_compile), linker::typeof(CUDA.cufunction_link))
    @ GPUCompiler /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/cache.jl:90        
  [8] cufunction(f::typeof(kernel_wmma_tf32_lowlevel), tt::Type{NTuple{4, CuDeviceMatrix{Float32, 1}}}; name::Nothing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ CUDA /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:297
  [9] cufunction
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:291 [inlined]
 [10] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:102 [inlined]
 [11] macro expansion
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/utilities.jl:25 [inlined]
 [12] call_kernel()
    @ Main ./REPL[4]:12
 [13] top-level scope
    @ REPL[5]:1
 [14] top-level scope
    @ /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/initialization.jl:52

So, now it complains about llvm.nvvm.wmma.m16n16k8.mma.col.col.f32.f32. That seems to be a fixable mistake on my end, since the correct intrinsic seems to be llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32.
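The mismatch comes from TF32 breaking the naming pattern of the other mma intrinsics. For f16 inputs, the name ends in the D and C element types (e.g. `...mma.col.col.f32.f32`). For TF32, the name found above ends in a single `tf32`. A hypothetical helper that encodes this special case is shown below. It is inferred from the names in this thread, not taken from CUDA.jl or LLVM:

```julia
# Hypothetical helper: NVVM intrinsic name for a WMMA mma.
# f16 inputs: name ends in the D and C element types.
# tf32 inputs: name ends in a single "tf32" (as found in this thread).
function wmma_mma_intrinsic(a_layout, b_layout, shape, ab_type, d_type, c_type)
    base = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout"
    return ab_type == "tf32" ? "$base.tf32" : "$base.$d_type.$c_type"
end
```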

carstenbauer · Mar 02 '22 17:03

Good news: with the current state (and Julia >= 1.8), the errors above are gone with this corrected example:

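using CUDA

# The mma wrapper is now named after the TF32 intrinsic llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32.
function kernel_wmma_tf32_lowlevel(a_dev, b_dev, c_dev, d_dev)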
    a_frag = WMMA.llvm_wmma_load_a_col_m16n16k8_global_stride_tf32(pointer(a_dev), 16)
    b_frag = WMMA.llvm_wmma_load_b_col_m16n16k8_global_stride_tf32(pointer(b_dev), 8)
    c_frag = WMMA.llvm_wmma_load_c_col_m16n16k8_global_stride_f32(pointer(c_dev), 16)

    d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k8_tf32(a_frag, b_frag, c_frag)

    WMMA.llvm_wmma_store_d_col_m16n16k8_global_stride_f32(pointer(d_dev), d_frag, 16)
    return nothing
end

function call_kernel()
    m = n = 16
    k = 8
    dtype_a = dtype_b = Float32
    dtype_c = dtype_d = Float32

    d_a = CUDA.rand(dtype_a, m, k)
    d_b = CUDA.rand(dtype_b, k, n)
    d_c = CUDA.rand(dtype_c, m, n)
    d_d = CUDA.zeros(dtype_d, m, n)

    CUDA.@sync @cuda threads=32 kernel_wmma_tf32_lowlevel(d_a, d_b, d_c, d_d)  # WMMA ops are warp-wide: launch one full warp
    return nothing
end
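Once the kernel compiles and runs, its output still has to be checked against D = A * B + C. Here is a minimal CPU-side sketch. `check_wmma_result` is hypothetical, and the tolerance `rtol = 1f-2` is a loose assumption reflecting TF32's 10-bit mantissa, not a documented error bound:

```julia
using LinearAlgebra

# Hypothetical check: compare the kernel's D with a CPU reference A * B + C.
# isapprox on arrays compares norms: norm(D - ref) <= rtol * max(norm(D), norm(ref)).
function check_wmma_result(A::Matrix{Float32}, B::Matrix{Float32},
                           C::Matrix{Float32}, D::Matrix{Float32}; rtol = 1f-2)
    ref = A * B + C
    return isapprox(D, ref; rtol = rtol)
end
```

In `call_kernel`, this could be called after the launch as `check_wmma_result(Array(d_a), Array(d_b), Array(d_c), Array(d_d))`, which copies the device arrays back to the host first.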

Bad news: I now get a segfault instead 😄

signal (11): Segmentation fault                                                                                                                                                    
in expression starting at REPL[5]:1                                                                                                                                                
_ZN4llvm19MachineRegisterInfo17constrainRegClassENS_8RegisterEPKNS_19TargetRegisterClassEj at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm12InstrEmitter18AddRegisterOperandERNS_19MachineInstrBuilderENS_7SDValueEjPKNS_11MCInstrDescERNS_8DenseMapIS3_NS_8RegisterENS_12DenseMapInfoIS3_EENS_6detail12DenseMapPairIS3_S8_EEEEbbb at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm12InstrEmitter15EmitMachineNodeEPNS_6SDNodeEbbRNS_8DenseMapINS_7SDValueENS_8RegisterENS_12DenseMapInfoIS4_EENS_6detail12DenseMapPairIS4_S5_EEEE at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm18ScheduleDAGSDNodes12EmitScheduleERNS_26MachineInstrBundleIteratorINS_12MachineInstrELb0EEE at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel17CodeGenAndEmitDAGEv at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)          
_ZN4llvm16SelectionDAGISel20SelectAllBasicBlocksERKNS_8FunctionE at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel20runOnMachineFunctionERNS_15MachineFunctionE.part.899 at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm19MachineFunctionPass13runOnFunctionERNS_8FunctionE at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm13FPPassManager13runOnFunctionERNS_8FunctionE at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm13FPPassManager11runOnModuleERNS_6ModuleE at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZN4llvm6legacy15PassManagerImpl3runERNS_6ModuleE at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
_ZL21LLVMTargetMachineEmitP23LLVMOpaqueTargetMachineP16LLVMOpaqueModuleRN4llvm17raw_pwrite_streamE19LLVMCodeGenFileTypePPc at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
LLVMTargetMachineEmitToMemoryBuffer at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/bin/../lib/julia/libLLVM-13jl.so (unknown line)
LLVMTargetMachineEmitToMemoryBuffer at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/LLVM/MJqe4/lib/13/libLLVM_h.jl:947 [inlined]
emit at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/LLVM/MJqe4/src/targetmachine.jl:45
mcgen at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/mcgen.jl:74
unknown function (ip: 0x14bc65bb239f)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2340 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2522
macro expansion at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/TimerOutputs/5tW2E/src/TimerOutput.jl:252 [inlined]
macro expansion at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/driver.jl:339 [inlined]
macro expansion at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/TimerOutputs/5tW2E/src/TimerOutput.jl:252 [inlined]
macro expansion at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/driver.jl:336 [inlined]
#emit_asm#137 at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/utils.jl:64
emit_asm##kw at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/utils.jl:62 [inlined]
cufunction_compile at /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:326
cached_compilation at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/cache.jl:90
#cufunction#255 at /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:297
cufunction at /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:291 [inlined]
macro expansion at /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/compiler/execution.jl:102 [inlined]                                                      
macro expansion at /scratch/pc2-mitarbeiter/bauerc/devel/PC2GPUBenchmarks.jl/dev/CUDA/src/utilities.jl:25 [inlined]                                                                
call_kernel at ./REPL[3]:12                                                                                                                                                        
unknown function (ip: 0x14bc65b3e19f)                                                                                                                                              
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2340 [inlined]                                                                                                    
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2522                                                                                                       
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1825 [inlined]                                                                                                   
do_call at /buildworker/worker/package_linux64/build/src/interpreter.c:126                                                                                                         
eval_value at /buildworker/worker/package_linux64/build/src/interpreter.c:215                                                                                                      
eval_stmt_value at /buildworker/worker/package_linux64/build/src/interpreter.c:166 [inlined]                                                                                       
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:612                                                                                                       
jl_interpret_toplevel_thunk at /buildworker/worker/package_linux64/build/src/interpreter.c:750                                                                                     
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:906                                                                                              
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:850                                                                                              
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:556                                                                                                       
eval_body at /buildworker/worker/package_linux64/build/src/interpreter.c:522                                                                                                       
jl_interpret_toplevel_thunk at /buildworker/worker/package_linux64/build/src/interpreter.c:750
jl_toplevel_eval_flex at /buildworker/worker/package_linux64/build/src/toplevel.c:906
ijl_toplevel_eval_in at /buildworker/worker/package_linux64/build/src/toplevel.c:965
eval at ./boot.jl:368 [inlined]                                                                                                                                                    
eval_user_input at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:151
repl_backend_loop at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:247
start_repl_backend at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:232
#run_repl#47 at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:369
run_repl at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/REPL/src/REPL.jl:356
jfptr_run_repl_64273.clone_1 at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2340 [inlined]                                                                                                    
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2522         
#960 at ./client.jl:419                                                                  
jfptr_YY.960_31015.clone_1 at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2340 [inlined]                                                                                                    
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2522                                                                                                       
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1825 [inlined]                                                                                                   
jl_f__call_latest at /buildworker/worker/package_linux64/build/src/builtins.c:769                                                                                                  
#invokelatest#2 at ./essentials.jl:729 [inlined]                                                                                                                                   
invokelatest at ./essentials.jl:727 [inlined]                                                                                                                                      
run_main_repl at ./client.jl:404                                                         
exec_options at ./client.jl:318                                                          
_start at ./client.jl:522
jfptr__start_59889.clone_1 at /scratch/pc2-mitarbeiter/bauerc/.julia/juliaup/julia-1.8.0-beta1+0~x64/lib/julia/sys.so (unknown line)
_jl_invoke at /buildworker/worker/package_linux64/build/src/gf.c:2340 [inlined]
ijl_apply_generic at /buildworker/worker/package_linux64/build/src/gf.c:2522
jl_apply at /buildworker/worker/package_linux64/build/src/julia.h:1825 [inlined]
true_main at /buildworker/worker/package_linux64/build/src/jlapi.c:562           
jl_repl_entrypoint at /buildworker/worker/package_linux64/build/src/jlapi.c:706
main at julia-beta (unknown line)                                                        
__libc_start_main at /lib64/libc.so.6 (unknown line)
unknown function (ip: 0x400808)
Allocations: 46651882 (Pool: 46631992; Big: 19890); GC: 45
Segmentation fault (core dumped)

carstenbauer · Mar 02 '22 17:03

> so I think that the intrinsic is correct. (Am I right in assuming that the names in the comments are the ones to check?)

Yes, exactly.

> If the name is correct, we're likely missing the latest NVPTX patches in Julia.

> Can you elaborate? Do you mean that our LLVM is too old and we need specific NVPTX patches / backports? If so, I guess that would imply that there is nothing that can be done here (and that things can only work from Julia 1.9 on?)

Yes, this is probably the patch you need, and it is missing in Julia 1.7: https://reviews.llvm.org/D104847.
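If TF32 tests should be skipped on Julia versions whose LLVM lacks these intrinsics, a version guard is one option. The sketch below assumes LLVM 13 as the threshold. That assumption comes only from this thread: with Julia 1.7 (LLVM 12) the TF32 loads were unknown, while Julia 1.8-beta1 (LLVM 13) accepted them. `tf32_wmma_llvm_ok` is hypothetical. Passing the guard is necessary but not sufficient, since LLVM 13 still segfaulted during instruction selection.

```julia
# Hypothetical guard (not CUDA.jl API): only try TF32 WMMA when Julia links
# against LLVM 13 or newer. The threshold is an assumption taken from this thread.
tf32_wmma_llvm_ok(llvm_version::VersionNumber = Base.libllvm_version) =
    llvm_version >= v"13"
```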

> Bad news: I now get a segfault instead

Looks like a bug in instruction selection in LLVM. I think @maleadt can help you more than I can with that issue.

thomasfaingnaert · Mar 03 '22 07:03

Note that I see a similar (the same?) segfault with f64 wmma: https://github.com/JuliaGPU/CUDA.jl/pull/1426#issuecomment-1057864919

carstenbauer · Mar 03 '22 09:03

Try using an assertions build. That should become easier as soon as I tag the next version of LLVM.jl (today, normally).

maleadt · Mar 03 '22 10:03

> Try using an assertions build. That should become easier as soon as I tag the next version of LLVM.jl (today, normally).

When I use an assertions build of Julia, I get

julia> using CUDA
[ Info: Precompiling CUDA [052768ef-5323-5732-b1bb-66c8b64840ba]
┌ Warning: You are using a version of Julia that links against a build of LLVM with assertions enabled.
│
│ This is not supported out-of-the-box, and you need a build of libLLVMExtra that supports this.
│ Use `deps/build_local.jl` for that, add the resulting LocalPreferences.toml to your project
│ and add a direct dependency on LLVMExtra_jll to pick up those preferences.
└ @ LLVM /scratch/pc2-mitarbeiter/bauerc/.julia/packages/LLVM/MJqe4/src/LLVM.jl:24
[...]

AFAICT, there is no deps/build_local.jl. Can you elaborate a bit on what I should do?

carstenbauer · Mar 03 '22 14:03

Alright, after running ] up, the assertions build no longer complains. Now I get

julia> call_kernel()
julia: /workspace/srcdir/llvm-project/llvm/include/llvm/CodeGen/SelectionDAGNodes.h:1105: llvm::SDValue::SDValue(llvm::SDNode*, unsigned int): Assertion `(!Node || !ResNo || ResNo < Node->getNumValues()) && "Invalid result number for the given node!"' failed.

signal (6): Aborted
in expression starting at REPL[9]:1
gsignal at /lib64/libc.so.6 (unknown line)
abort at /lib64/libc.so.6 (unknown line)
__assert_fail_base.cold.0 at /lib64/libc.so.6 (unknown line)
__assert_fail at /lib64/libc.so.6 (unknown line)
_ZN4llvm7SDValueC1EPNS_6SDNodeEj at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel9MorphNodeEPNS_6SDNodeEjNS_8SDVTListENS_8ArrayRefINS_7SDValueEEEj at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel16SelectCodeCommonEPNS_6SDNodeEPKhj at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel22DoInstructionSelectionEv at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel17CodeGenAndEmitDAGEv at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel20SelectAllBasicBlocksERKNS_8FunctionE at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm16SelectionDAGISel20runOnMachineFunctionERNS_15MachineFunctionE.part.972 at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm19MachineFunctionPass13runOnFunctionERNS_8FunctionE at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm13FPPassManager13runOnFunctionERNS_8FunctionE at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm13FPPassManager11runOnModuleERNS_6ModuleE at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZN4llvm6legacy15PassManagerImpl3runERNS_6ModuleE at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
_ZL21LLVMTargetMachineEmitP23LLVMOpaqueTargetMachineP16LLVMOpaqueModuleRN4llvm17raw_pwrite_streamE19LLVMCodeGenFileTypePPc at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
LLVMTargetMachineEmitToMemoryBuffer at /scratch/pc2-mitarbeiter/bauerc/building/julia/julia-source/usr/bin/../lib/libLLVM-13jl.so (unknown line)
LLVMTargetMachineEmitToMemoryBuffer at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/LLVM/P1t7P/lib/13/libLLVM_h.jl:947 [inlined]
emit at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/LLVM/P1t7P/src/targetmachine.jl:45
mcgen at /scratch/pc2-mitarbeiter/bauerc/.julia/packages/GPUCompiler/I9fZc/src/mcgen.jl:74
unknown function (ip: 0x152e1961abdf)
[...]

carstenbauer avatar Mar 03 '22 14:03 carstenbauer

Can you show the IR?

maleadt avatar Mar 03 '22 15:03 maleadt

You want the LLVM IR, i.e. the output of @device_code_llvm, right? I get some output, but then it segfaults as well.

julia> @device_code_llvm dump_module=true call_kernel()
; PTX CompilerJob of kernel kernel_wmma_tf32_lowlevel(CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}) for sm_80                                                          
; ModuleID = 'text'                                                                                                                                                                                                                      
source_filename = "text"                                                                                                                                                                                                                 
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"                                                                                  
target triple = "nvptx64-nvidia-cuda"                                                                                                                                                                                                    
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind readonly                                                                                                                                                                                           
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p1i8(i8 addrspace(1)* nocapture readonly, i32) #0                                                                     
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind readonly                                                                                                                                                                                           
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p1i8(i8 addrspace(1)* nocapture readonly, i32) #0                                                                     
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind readonly                                                                                                                                                                                           
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32.p1i8(i8 addrspace(1)* nocapture readonly, i32) #0                                                                      
                                                                                                                                                                                                                                         
; Function Attrs: nounwind readnone                                                                                                                                                                                                      
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float) #1
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind writeonly                                                                                                                                                                                          
declare void @llvm.nvvm.wmma.m16n16k8.store.d.col.stride.f32.p1i8(i8 addrspace(1)* nocapture writeonly, float, float, float, float, float, float, float, float, i32) #2                                                                  
                                                                                                                                                                                                                                         
;  @ REPL[1]:1 within `kernel_wmma_tf32_lowlevel`                                                                                                                                                                                        
define ptx_kernel void @_Z36julia_kernel_wmma_tf32_lowlevel_263813CuDeviceArrayI7Float32Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EE([1 x i64] %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { i8 addrspace(1)*, i64, [2 x i64], i64 } %2, { i8 addrspace(1)*, i64, [2 x i64], i64 } %3) local_unnamed_addr #3 {
entry:                                                                                                                                                                                                                                   
  %.fca.0.extract72 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0                                                                                                                                                       
  %.fca.0.extract62 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0                                                                                                                                                       
  %.fca.0.extract52 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %2, 0                                                                                                                                                       
  %.fca.0.extract49 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %3, 0                                                                                                                                                       
;  @ REPL[1]:2 within `kernel_wmma_tf32_lowlevel`                                                                                                                                                                                        
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:214 within `llvm_wmma_load_a_col_m16n16k8_global_stride_tf32`                                                                                             
   %4 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p1i8(i8 addrspace(1)* %.fca.0.extract72, i32 16)                                                                 
   %.fca.0.extract33 = extractvalue { float, float, float, float, float, float, float, float } %4, 0                                                                                                                                     
   %.fca.1.extract34 = extractvalue { float, float, float, float, float, float, float, float } %4, 1                                                                                                                                     
   %.fca.2.extract35 = extractvalue { float, float, float, float, float, float, float, float } %4, 2                                                                                                                                     
   %.fca.3.extract36 = extractvalue { float, float, float, float, float, float, float, float } %4, 3                                                                                                                                     
   %.fca.4.extract37 = extractvalue { float, float, float, float, float, float, float, float } %4, 4                                                                                                                                     
   %.fca.5.extract38 = extractvalue { float, float, float, float, float, float, float, float } %4, 5                                                                                                                                     
   %.fca.6.extract39 = extractvalue { float, float, float, float, float, float, float, float } %4, 6                                                                                                                                     
   %.fca.7.extract40 = extractvalue { float, float, float, float, float, float, float, float } %4, 7                                                                                                                                     
; └                                                                                                                                                                                                                                      
;  @ REPL[1]:3 within `kernel_wmma_tf32_lowlevel`                                                                                                                                                                                        
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:214 within `llvm_wmma_load_b_col_m16n16k8_global_stride_tf32`                                                                                             
   %5 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p1i8(i8 addrspace(1)* %.fca.0.extract62, i32 8)                                                                  
   %.fca.0.extract17 = extractvalue { float, float, float, float, float, float, float, float } %5, 0                                                                                                                                     
   %.fca.1.extract18 = extractvalue { float, float, float, float, float, float, float, float } %5, 1                                                                                                                                     
   %.fca.2.extract19 = extractvalue { float, float, float, float, float, float, float, float } %5, 2                                                                                                                                     
   %.fca.3.extract20 = extractvalue { float, float, float, float, float, float, float, float } %5, 3                                                                                                                                     
   %.fca.4.extract21 = extractvalue { float, float, float, float, float, float, float, float } %5, 4                                                                                                                                     
   %.fca.5.extract22 = extractvalue { float, float, float, float, float, float, float, float } %5, 5                                                                                                                                     
   %.fca.6.extract23 = extractvalue { float, float, float, float, float, float, float, float } %5, 6                                                                                                                                     
   %.fca.7.extract24 = extractvalue { float, float, float, float, float, float, float, float } %5, 7                                                                                                                                     
; └                                                                                                                                                                                                                                      
;  @ REPL[1]:4 within `kernel_wmma_tf32_lowlevel`                                                                                                                                                                                        
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:214 within `llvm_wmma_load_c_col_m16n16k8_global_stride_f32`                                                                                              
   %6 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32.p1i8(i8 addrspace(1)* %.fca.0.extract52, i32 16)                                                                  
   %.fca.0.extract1 = extractvalue { float, float, float, float, float, float, float, float } %6, 0
   %.fca.1.extract2 = extractvalue { float, float, float, float, float, float, float, float } %6, 1
   %.fca.2.extract3 = extractvalue { float, float, float, float, float, float, float, float } %6, 2
   %.fca.3.extract4 = extractvalue { float, float, float, float, float, float, float, float } %6, 3
   %.fca.4.extract5 = extractvalue { float, float, float, float, float, float, float, float } %6, 4
   %.fca.5.extract6 = extractvalue { float, float, float, float, float, float, float, float } %6, 5
   %.fca.6.extract7 = extractvalue { float, float, float, float, float, float, float, float } %6, 6
   %.fca.7.extract8 = extractvalue { float, float, float, float, float, float, float, float } %6, 7
; └
;  @ REPL[1]:6 within `kernel_wmma_tf32_lowlevel`
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:355 within `llvm_wmma_mma_col_col_m16n16k8_tf32`
   %7 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(float %.fca.0.extract33, float %.fca.1.extract34, float %.fca.2.extract35, float %.fca.3.extract36, float %.fca.4.extract37, float %.fca.5.extract38, float %.fca.6.extract39, float %.fca.7.extract40, float %.fca.0.extract17, float %.fca.1.extract18, float %.fca.2.extract19, float %.fca.3.extract20, float %.fca.4.extract21, float %.fca.5.extract22, float %.fca.6.extract23, float %.fca.7.extract24, float %.fca.0.extract1, float %.fca.1.extract2, float %.fca.2.extract3, float %.fca.3.extract4, float %.fca.4.extract5, float %.fca.5.extract6, float %.fca.6.extract7, float %.fca.7.extract8)
   %.fca.0.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 0
   %.fca.1.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 1
   %.fca.2.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 2
   %.fca.3.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 3
   %.fca.4.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 4
   %.fca.5.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 5
   %.fca.6.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 6
   %.fca.7.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 7
; └
;  @ REPL[1]:8 within `kernel_wmma_tf32_lowlevel`
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:277 within `llvm_wmma_store_d_col_m16n16k8_global_stride_f32`
   call void @llvm.nvvm.wmma.m16n16k8.store.d.col.stride.f32.p1i8(i8 addrspace(1)* %.fca.0.extract49, float %.fca.0.extract, float %.fca.1.extract, float %.fca.2.extract, float %.fca.3.extract, float %.fca.4.extract, float %.fca.5.extract, float %.fca.6.extract, float %.fca.7.extract, i32 16)
; └
;  @ REPL[1]:9 within `kernel_wmma_tf32_lowlevel`
  ret void
}

attributes #0 = { argmemonly nounwind readonly }
attributes #1 = { nounwind readnone }
attributes #2 = { argmemonly nounwind writeonly }
attributes #3 = { "probe-stack"="inline-asm" }

!llvm.module.flags = !{!0, !1}
!nvvm.annotations = !{!2}

!0 = !{i32 2, !"Dwarf Version", i32 4}
!1 = !{i32 2, !"Debug Info Version", i32 3}
!2 = !{void ([1 x i64], { i8 addrspace(1)*, i64, [2 x i64], i64 }, { i8 addrspace(1)*, i64, [2 x i64], i64 }, { i8 addrspace(1)*, i64, [2 x i64], i64 }, { i8 addrspace(1)*, i64, [2 x i64], i64 })* @_Z36julia_kernel_wmma_tf32_lowlevel_263813CuDeviceArrayI7Float32Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EE, !"kernel", i32 1}
julia: /workspace/srcdir/llvm-project/llvm/include/llvm/CodeGen/SelectionDAGNodes.h:1105: llvm::SDValue::SDValue(llvm::SDNode*, unsigned int): Assertion `(!Node || !ResNo || ResNo < Node->getNumValues()) && "Invalid result number for the given node!"' failed.

signal (6): Aborted
[...]

carstenbauer avatar Mar 03 '22 18:03 carstenbauer

You want to include dump_module=true so that we can process the IR with llc outside of Julia. If that still asserts, we can try reducing it to figure out whether we're generating bad code or whether this is an LLVM MC issue.

maleadt avatar Mar 04 '22 06:03 maleadt

I updated my last comment and included dump_module=true.

carstenbauer avatar Mar 04 '22 14:03 carstenbauer

The line breaks were corrupted. Here's the fixed version:

; PTX CompilerJob of kernel kernel_wmma_tf32_lowlevel(CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}, CuDeviceMatrix{Float32, 1}) for sm_80                                                          
; ModuleID = 'text'                                                                                                                                                                                                                      
source_filename = "text"                                                                                                                                                                                                                 
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"                                                                                  
target triple = "nvptx64-nvidia-cuda"                                                                                                                                                                                                    
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind readonly                                                                                                                                                                                           
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p1i8(i8 addrspace(1)* nocapture readonly, i32) #0                                                                     
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind readonly                                                                                                                                                                                           
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p1i8(i8 addrspace(1)* nocapture readonly, i32) #0                                                                     
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind readonly                                                                                                                                                                                           
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32.p1i8(i8 addrspace(1)* nocapture readonly, i32) #0                                                                      
                                                                                                                                                                                                                                         
; Function Attrs: nounwind readnone                                                                                                                                                                                                      
declare { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float) #1
                                                                                                                                                                                                                                         
; Function Attrs: argmemonly nounwind writeonly                                                                                                                                                                                          
declare void @llvm.nvvm.wmma.m16n16k8.store.d.col.stride.f32.p1i8(i8 addrspace(1)* nocapture writeonly, float, float, float, float, float, float, float, float, i32) #2                                                                  
                                                                                                                                                                                                                                         
;  @ REPL[1]:1 within `kernel_wmma_tf32_lowlevel`                                                                                                                                                                                        
define ptx_kernel void @_Z36julia_kernel_wmma_tf32_lowlevel_263813CuDeviceArrayI7Float32Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EE([1 x i64] %state, { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, { i8 addrspace(1)*, i64, [2 x i64], i64 } %2, { i8 addrspace(1)*, i64, [2 x i64], i64 } %3) local_unnamed_addr #3 {
entry:                                                                                                                                                                                                                                   
  %.fca.0.extract72 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %0, 0                                                                                                                                                       
  %.fca.0.extract62 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %1, 0                                                                                                                                                       
  %.fca.0.extract52 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %2, 0                                                                                                                                                       
  %.fca.0.extract49 = extractvalue { i8 addrspace(1)*, i64, [2 x i64], i64 } %3, 0                                                                                                                                                       
;  @ REPL[1]:2 within `kernel_wmma_tf32_lowlevel`                                                                                                                                                                                        
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:214 within `llvm_wmma_load_a_col_m16n16k8_global_stride_tf32`                                                                                             
   %4 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p1i8(i8 addrspace(1)* %.fca.0.extract72, i32 16)                                                                 
   %.fca.0.extract33 = extractvalue { float, float, float, float, float, float, float, float } %4, 0                                                                                                                                     
   %.fca.1.extract34 = extractvalue { float, float, float, float, float, float, float, float } %4, 1                                                                                                                                     
   %.fca.2.extract35 = extractvalue { float, float, float, float, float, float, float, float } %4, 2                                                                                                                                     
   %.fca.3.extract36 = extractvalue { float, float, float, float, float, float, float, float } %4, 3                                                                                                                                     
   %.fca.4.extract37 = extractvalue { float, float, float, float, float, float, float, float } %4, 4
   %.fca.5.extract38 = extractvalue { float, float, float, float, float, float, float, float } %4, 5
   %.fca.6.extract39 = extractvalue { float, float, float, float, float, float, float, float } %4, 6
   %.fca.7.extract40 = extractvalue { float, float, float, float, float, float, float, float } %4, 7
; └
;  @ REPL[1]:3 within `kernel_wmma_tf32_lowlevel`
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:214 within `llvm_wmma_load_b_col_m16n16k8_global_stride_tf32`
   %5 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p1i8(i8 addrspace(1)* %.fca.0.extract62, i32 8)
   %.fca.0.extract17 = extractvalue { float, float, float, float, float, float, float, float } %5, 0
   %.fca.1.extract18 = extractvalue { float, float, float, float, float, float, float, float } %5, 1
   %.fca.2.extract19 = extractvalue { float, float, float, float, float, float, float, float } %5, 2
   %.fca.3.extract20 = extractvalue { float, float, float, float, float, float, float, float } %5, 3
   %.fca.4.extract21 = extractvalue { float, float, float, float, float, float, float, float } %5, 4
   %.fca.5.extract22 = extractvalue { float, float, float, float, float, float, float, float } %5, 5
   %.fca.6.extract23 = extractvalue { float, float, float, float, float, float, float, float } %5, 6
   %.fca.7.extract24 = extractvalue { float, float, float, float, float, float, float, float } %5, 7
; └
;  @ REPL[1]:4 within `kernel_wmma_tf32_lowlevel`
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:214 within `llvm_wmma_load_c_col_m16n16k8_global_stride_f32`
   %6 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.load.c.col.stride.f32.p1i8(i8 addrspace(1)* %.fca.0.extract52, i32 16)
   %.fca.0.extract1 = extractvalue { float, float, float, float, float, float, float, float } %6, 0
   %.fca.1.extract2 = extractvalue { float, float, float, float, float, float, float, float } %6, 1
   %.fca.2.extract3 = extractvalue { float, float, float, float, float, float, float, float } %6, 2
   %.fca.3.extract4 = extractvalue { float, float, float, float, float, float, float, float } %6, 3
   %.fca.4.extract5 = extractvalue { float, float, float, float, float, float, float, float } %6, 4
   %.fca.5.extract6 = extractvalue { float, float, float, float, float, float, float, float } %6, 5
   %.fca.6.extract7 = extractvalue { float, float, float, float, float, float, float, float } %6, 6
   %.fca.7.extract8 = extractvalue { float, float, float, float, float, float, float, float } %6, 7
; └
;  @ REPL[1]:6 within `kernel_wmma_tf32_lowlevel`
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:355 within `llvm_wmma_mma_col_col_m16n16k8_tf32`
   %7 = call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32(float %.fca.0.extract33, float %.fca.1.extract34, float %.fca.2.extract35, float %.fca.3.extract36, float %.fca.4.extract37, float %.fca.5.extract38, float %.fca.6.extract39, float %.fca.7.extract40, float %.fca.0.extract17, float %.fca.1.extract18, float %.fca.2.extract19, float %.fca.3.extract20, float %.fca.4.extract21, float %.fca.5.extract22, float %.fca.6.extract23, float %.fca.7.extract24, float %.fca.0.extract1, float %.fca.1.extract2, float %.fca.2.extract3, float %.fca.3.extract4, float %.fca.4.extract5, float %.fca.5.extract6, float %.fca.6.extract7, float %.fca.7.extract8)
   %.fca.0.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 0
   %.fca.1.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 1
   %.fca.2.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 2
   %.fca.3.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 3
   %.fca.4.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 4
   %.fca.5.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 5
   %.fca.6.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 6
   %.fca.7.extract = extractvalue { float, float, float, float, float, float, float, float } %7, 7
; └
;  @ REPL[1]:8 within `kernel_wmma_tf32_lowlevel`
; ┌ @ /scratch/pc2-mitarbeiter/bauerc/devel/CUDA/src/device/intrinsics/wmma.jl:277 within `llvm_wmma_store_d_col_m16n16k8_global_stride_f32`
   call void @llvm.nvvm.wmma.m16n16k8.store.d.col.stride.f32.p1i8(i8 addrspace(1)* %.fca.0.extract49, float %.fca.0.extract, float %.fca.1.extract, float %.fca.2.extract, float %.fca.3.extract, float %.fca.4.extract, float %.fca.5.extract, float %.fca.6.extract, float %.fca.7.extract, i32 16)
; └
;  @ REPL[1]:9 within `kernel_wmma_tf32_lowlevel`
  ret void
}

attributes #0 = { argmemonly nounwind readonly }
attributes #1 = { nounwind readnone }
attributes #2 = { argmemonly nounwind writeonly }
attributes #3 = { "probe-stack"="inline-asm" }

!llvm.module.flags = !{!0, !1}
!nvvm.annotations = !{!2}

!0 = !{i32 2, !"Dwarf Version", i32 4}
!1 = !{i32 2, !"Debug Info Version", i32 3}
!2 = !{void ([1 x i64], { i8 addrspace(1)*, i64, [2 x i64], i64 }, { i8 addrspace(1)*, i64, [2 x i64], i64 }, { i8 addrspace(1)*, i64, [2 x i64], i64 }, { i8 addrspace(1)*, i64, [2 x i64], i64 })* @_Z36julia_kernel_wmma_tf32_lowlevel_263813CuDeviceArrayI7Float32Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EES_IS0_Li2ELi1EE, !"kernel", i32 1}

If I try to process this IR with llc, I get:

Intrinsic has incorrect return type!
{ float, float, float, float, float, float, float, float } (i8 addrspace(1)*, i32)* @llvm.nvvm.wmma.m16n16k8.load.a.col.stride.tf32.p1i8
Intrinsic has incorrect return type!
{ float, float, float, float, float, float, float, float } (i8 addrspace(1)*, i32)* @llvm.nvvm.wmma.m16n16k8.load.b.col.stride.tf32.p1i8
Intrinsic has incorrect argument type!
{ float, float, float, float, float, float, float, float } (float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float, float)* @llvm.nvvm.wmma.m16n16k8.mma.col.col.tf32

So this is probably a misuse of the intrinsics?
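The llc messages fit a fragment-size mismatch: the IR above treats the tf32 A and B fragments like the f32 C fragment, as 8 floats each, so the mma call ends up with 8 + 8 + 8 = 24 arguments. As far as I can tell from the PTX ISA fragment table, for m16n16k8 a tf32 A or B fragment is four 32-bit registers per thread, while C and D hold eight f32 values. The arithmetic behind those numbers can be sketched in Julia; `elements_per_thread` and `WARP_SIZE` are made-up names for illustration, not CUDA.jl API.

```julia
# Hypothetical helper, not part of CUDA.jl: how many elements each thread of a
# warp holds for one WMMA fragment, assuming the tile is spread evenly over the
# 32 threads with one element per 32-bit register (true for tf32 and f32, but
# not for f16, which packs two elements into each register).
const WARP_SIZE = 32

function elements_per_thread(rows::Integer, cols::Integer)
    total = rows * cols
    total % WARP_SIZE == 0 ||
        throw(ArgumentError("a $rows×$cols tile does not divide evenly over a warp"))
    return total ÷ WARP_SIZE
end

# Shape m16n16k8: A is m×k, B is k×n, C and D are m×n.
m, n, k = 16, 16, 8
frag_a = elements_per_thread(m, k)   # 128 / 32 = 4
frag_b = elements_per_thread(k, n)   # 128 / 32 = 4
frag_c = elements_per_thread(m, n)   # 256 / 32 = 8

# The mma intrinsic takes the A, B and C fragments as separate scalar arguments.
mma_args = frag_a + frag_b + frag_c  # 4 + 4 + 8 = 16
```

By this count the mma call in the IR passes 8 arguments too many, which is consistent with the "incorrect argument type" message, since LLVM checks the intrinsic's whole function type.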

maleadt avatar Mar 04 '22 14:03 maleadt

Thanks for helping me debug this, Tim! I was using the wrong fragment sizes. I fixed this in the latest commit, with which call_kernel() now runs without any issues. Will add more tests and update the docstrings next.
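For those tests, a plain Float32 `a * b + c` is only an approximate reference, because TF32 keeps just 10 explicit mantissa bits of the A and B inputs (Float32 has 23). Below is a minimal sketch of a CPU reference that models this by truncating the inputs to TF32 precision. `to_tf32` and `tf32_mma_reference` are hypothetical helpers, and whether the hardware truncates or rounds unconverted Float32 inputs is an assumption here, so the comparison should still use a tolerance.

```julia
# Hypothetical test helpers, not part of CUDA.jl.

# Clear the 13 low mantissa bits of a Float32, leaving the 1 sign, 8 exponent
# and 10 mantissa bits of TF32. This models truncation (an assumption about the
# hardware) and ignores the NaN edge case.
to_tf32(x::Float32) = reinterpret(Float32, reinterpret(UInt32, x) & 0xffffe000)

# CPU reference for D = A * B + C with TF32 inputs and a Float32 accumulator.
tf32_mma_reference(a::AbstractMatrix{Float32}, b::AbstractMatrix{Float32},
                   c::AbstractMatrix{Float32}) = to_tf32.(a) * to_tf32.(b) .+ c
```

On the GPU side, a check along the lines of `@test Array(d_d) ≈ tf32_mma_reference(Array(d_a), Array(d_b), Array(d_c)) rtol=1e-2` could then validate the kernel output. Since WMMA operations are executed cooperatively by all threads of a warp, such a test should launch the kernel with `@cuda threads=32 ...`; the `call_kernel()` shown earlier uses the default launch configuration of a single thread.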

carstenbauer avatar Mar 04 '22 15:03 carstenbauer

TODOs

  • [x] add tests
  • [x] docstrings
  • [ ] (maybe high-level API)

carstenbauer avatar Jun 13 '22 21:06 carstenbauer

Note: Tests are failing since they're running under Julia 1.6 (and 1.8 is required for this PR). Is there anything I can/should do about it on my end, @maleadt?

carstenbauer avatar Jun 14 '22 06:06 carstenbauer

Only 1.6 is used for draft PRs, but the tests shouldn't crash there; instead, you should check VERSION and skip unsupported combinations. Once that's in there, we can mark the PR as non-draft and run all tests.
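One way to follow this advice is to gate the TF32 tests on `VERSION` and record them as skipped on older Julia versions. A minimal sketch using only the Test stdlib; the 1.8 cutoff comes from the comment above, and `TF32_SUPPORTED` is a hypothetical name:

```julia
using Test

# Assumption based on this thread: the TF32 WMMA wrappers need Julia 1.8+.
const TF32_SUPPORTED = VERSION >= v"1.8"

@testset "WMMA TF32" begin
    if TF32_SUPPORTED
        # The real TF32 kernel tests would go here.
        @test true
    else
        # Recorded as "Broken" in the summary instead of erroring on Julia 1.6.
        @test_skip "TF32 WMMA requires Julia 1.8"
    end
end
```

`@test_skip` does not evaluate its expression, so nothing TF32-specific is compiled on older versions, and the skipped combination stays visible in the test summary. Wrapping the whole testset in `if VERSION >= v"1.8"` would also avoid the crash, but leaves no trace in the output.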

maleadt avatar Jun 14 '22 10:06 maleadt

If the tests pass (which they should), we can mark this as non-draft IMHO.

carstenbauer avatar Aug 18 '22 10:08 carstenbauer

LGTM. @thomasfaingnaert?

Should also be rebased so that CI runs.

maleadt avatar Oct 17 '22 07:10 maleadt

> Should also be rebased so that CI runs.

Done (I think).

carstenbauer avatar Oct 17 '22 08:10 carstenbauer

CI failures look related.

maleadt avatar Oct 17 '22 12:10 maleadt