CUDA.jl
Make Ref pass by-reference
julia> a = CUDA.rand(1)
julia> kernel(a) = (@inbounds a[1] = 0; nothing)
kernel (generic function with 1 method)
julia> @device_code_ptx @cuda kernel(a)
// PTX CompilerJob of kernel kernel(CuDeviceArray{Float32,1,CUDA.AS.Global}) for sm_75
//
// Generated by LLVM NVPTX Back-End
//
.version 6.3
.target sm_75
.address_size 64
// .globl _Z17julia_kernel_435613CuDeviceArrayI7Float32Li1E6GlobalE // -- Begin function _Z17julia_kernel_435613CuDeviceArrayI7Float32Li1E6GlobalE
.weak .global .align 8 .u64 exception_flag;
// @_Z17julia_kernel_435613CuDeviceArrayI7Float32Li1E6GlobalE
.visible .entry _Z17julia_kernel_435613CuDeviceArrayI7Float32Li1E6GlobalE(
.param .align 8 .b8 _Z17julia_kernel_435613CuDeviceArrayI7Float32Li1E6GlobalE_param_0[16]
)
{
.reg .b32 %r<2>;
.reg .b64 %rd<3>;
// %bb.0: // %top
mov.b64 %rd1, _Z17julia_kernel_435613CuDeviceArrayI7Float32Li1E6GlobalE_param_0;
ld.param.u64 %rd2, [%rd1+8];
mov.u32 %r1, 0;
st.global.u32 [%rd2], %r1;
ret;
// -- End function
}
julia> kernel(a) = (@inbounds a[][1] = 1; nothing)
kernel (generic function with 1 method)
julia> @device_code_ptx @cuda kernel(Ref(a))
// PTX CompilerJob of kernel kernel(CUDA.CuRefValue{CuDeviceArray{Float32,1,CUDA.AS.Global}}) for sm_75
//
// Generated by LLVM NVPTX Back-End
//
.version 6.3
.target sm_75
.address_size 64
// .globl _Z17julia_kernel_433710CuRefValueI13CuDeviceArrayI7Float32Li1E6GlobalEE // -- Begin function _Z17julia_kernel_433710CuRefValueI13CuDeviceArrayI7Float32Li1E6GlobalEE
.weak .global .align 8 .u64 exception_flag;
// @_Z17julia_kernel_433710CuRefValueI13CuDeviceArrayI7Float32Li1E6GlobalEE
.visible .entry _Z17julia_kernel_433710CuRefValueI13CuDeviceArrayI7Float32Li1E6GlobalEE(
.param .align 8 .b8 _Z17julia_kernel_433710CuRefValueI13CuDeviceArrayI7Float32Li1E6GlobalEE_param_0[16]
)
{
.reg .b32 %r<2>;
.reg .b64 %rd<3>;
// %bb.0: // %top
mov.b64 %rd1, _Z17julia_kernel_433710CuRefValueI13CuDeviceArrayI7Float32Li1E6GlobalEE_param_0;
ld.param.u64 %rd2, [%rd1+8];
mov.u32 %r1, 1065353216;
st.global.u32 [%rd2], %r1;
ret;
// -- End function
}
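A side note on the second listing: the stored constant 1065353216 is the bit pattern of the Float32 value 1.0, because the kernel assigns `1` into a Float32 array. A quick host-side check (not part of the original report, plain Julia without a GPU) would be:

```julia
# Reinterpret the Float32 value 1.0 as its raw 32-bit pattern.
bits = reinterpret(UInt32, 1.0f0)

# Compare against the immediate that appears in the PTX `mov.u32` instruction.
println(bits == 1065353216)
```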
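As the PTX shows, `Ref(a)` is currently passed by value: the `CuRefValue` occupies the same 16-byte kernel parameter as the bare array. Passing it by reference could be useful to work around parameter state space size restrictions like in https://github.com/CliMA/Oceananigans.jl/pull/746. However, it might pessimize Broadcast operations, where Ref is commonly used.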
Hi @maleadt,
I've come across an issue similar to https://github.com/CliMA/Oceananigans.jl/pull/746 while I've been trying to build a biogeochemical model on top of Oceananigans. I was wondering if you had any updated advice on solving the issue or if the above suggestion would now work as a workaround?
Thanks, Jago
Solving which issue?
If you want pass-by-reference behavior, use an array instead of a Ref for now.
Sorry, my issue is basically that I'm trying to pass a massive parameter, like in the Oceananigans issue, so I think pass-by-reference behaviour is what I'm after.
What do you mean by using an array?
Like in the first example at the top of this issue: use a single-element array.
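A minimal sketch of that workaround, assuming a hypothetical `Params` struct standing in for the large parameter (the struct, its field, and the sizes are illustrative and not taken from Oceananigans). The bundle is stored in a one-element `CuArray`, so only the array's pointer and size go through the kernel parameter space while the payload stays in device memory:

```julia
using CUDA

# Hypothetical parameter bundle (an assumption for illustration). It has to be
# an isbits type so that it can be stored in a CuArray.
struct Params
    coeffs::NTuple{256, Float32}
end

# `p` arrives in the kernel as a device vector of Params: only its pointer and
# size are passed as the kernel argument, the payload is read from global memory.
function kernel(out, p)
    i = threadIdx().x
    @inbounds out[i] = p[1].coeffs[i]
    return nothing
end

params = Params(ntuple(i -> Float32(i), 256))
p_gpu  = CuArray([params])           # single-element array holding the bundle
out    = CUDA.zeros(Float32, 256)

@cuda threads=256 kernel(out, p_gpu)
```

Reading `out` back with `Array(out)` lets you check the result on the host. The trade-off is that every access to `p[1]` is a global memory load rather than a read from parameter space, which may matter for hot kernels.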
I see, thank you