KernelAbstractions.jl
KernelAbstractions.jl copied to clipboard
How to improve CPU performance?
Consider this example from a JAX discussion:
Source code
"""
    run_julia(height, width)

Compute escape-time counts for the iteration `z -> z^2 + c`, started at `z = c`, on a
`height × width` grid of single-precision complex points. Real parts span `[-1.5, 0]`
across the columns and imaginary parts span `[-1, 0]` down the rows.

Returns a `Matrix{Int32}` whose entries hold the first step (1 through 20) at which
`abs2(z) > 4`, or `20` if the point never exceeds that bound within the 20 steps.
"""
function run_julia(height, width)
    # Float32 everywhere, because Jax defaults to single precision
    xs = range(-1.5f0, 0.0f0; length = width)
    ys = range(-1.0f0, 0.0f0; length = height)
    c = xs' .+ ys * im
    fractal = fill(Int32(20), height, width)
    # `eachindex` over both arrays checks that the indices of `c` and `fractal`
    # are compatible before the bounds checks are switched off
    @inbounds for idx in eachindex(c, fractal)
        seed = c[idx]
        z = seed
        active = true
        # The 20 steps are unrolled at macro-expansion time; `ifelse` plus the
        # `active` mask record the first escaping step without branching
        Base.Cartesian.@nexprs 20 k -> begin
            z = z^2 + seed
            escaped = abs2(z) > 4.0f0
            # 32-bit Int, for the same Jax-compatibility reason as above
            fractal[idx] = ifelse(active & escaped, Int32(k), fractal[idx])
            active &= !escaped
        end
    end
    return fractal
end
using KernelAbstractions
# Escape-time kernel: each work-item takes the point of `c` at its global index,
# iterates `z -> z^2 + c` up to 20 times starting from `z = c`, and stores the
# first step at which `abs2(z) > 4` into `fractal`. Entries whose point never
# exceeds the bound are left untouched, so `fractal` must be pre-filled by the caller.
@kernel function julia_kernel!(c, fractal)
    gid = @index(Global)
    seed = c[gid]
    z = seed
    @inbounds for step_no in 1:20
        z = z^2 + seed
        if abs2(z) > 4.0f0
            fractal[gid] = Int32(step_no)
            break # only the first escaping step is recorded
        end
    end
end
"""
    run_julia_cpu_jaxstype(height, width)

Compute the same `height × width` escape-time grid as `run_julia`, but by launching
`julia_kernel!` on the KernelAbstractions `CPU()` backend over every element of the
complex grid.

Returns the `Matrix{Int32}` of escape steps; entries the kernel never writes keep the
initial value `20`.
"""
function run_julia_cpu_jaxstype(height, width)
    y = range(-1.0f0, 0.0f0; length = height)
    x = range(-1.5f0, 0.0f0; length = width)
    c = x' .+ y * im
    fractal = fill(Int32(20), height, width)
    # Split the index range into roughly one workgroup per thread (the benchmark
    # runs with a single thread). Clamp to 1 so a grid with fewer points than
    # threads does not request a zero-sized workgroup.
    workgroupsize = max(1, length(c) ÷ Threads.nthreads())
    kernel! = julia_kernel!(CPU(), workgroupsize)
    event = kernel!(c, fractal; ndrange = length(c))
    wait(event) # nothing is copied back, so block here until the kernel is done
    return fractal
end
With these definitions, we have the following benchmark results:
julia> Threads.nthreads()
1
julia> @benchmark run_julia(2000,3000)
BenchmarkTools.Trial: 100 samples with 1 evaluation.
Range (min … max): 49.380 ms … 52.606 ms ┊ GC (min … max): 0.37% … 1.04%
Time (median): 49.789 ms ┊ GC (median): 0.98%
Time (mean ± σ): 49.982 ms ± 583.818 μs ┊ GC (mean ± σ): 0.98% ± 0.07%
▅█▇▄ ▃
▃▁▅▇█████▃▇██▅▇▅▁▆▆▃▆▃▁▆▃▅▁▃▁▁▁▁▁▃▁▁▁▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▆ ▃
49.4 ms Histogram: frequency by time 52 ms <
Memory estimate: 68.66 MiB, allocs estimate: 4.
julia> @benchmark run_julia_cpu_jaxstype(2000,3000)
BenchmarkTools.Trial: 29 samples with 1 evaluation.
Range (min … max): 176.674 ms … 179.932 ms ┊ GC (min … max): 0.27% … 0.08%
Time (median): 177.228 ms ┊ GC (median): 0.26%
Time (mean ± σ): 177.528 ms ± 780.443 μs ┊ GC (mean ± σ): 0.26% ± 0.03%
▃▃ ▃▃ ▃ █
██▇██▇▁█▇▇▇▁▇▁▁▁▁▇▁▁▇▇▁▁▇▇█▇▁▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
177 ms Histogram: frequency by time 180 ms <
Memory estimate: 68.67 MiB, allocs estimate: 28.