KernelAbstractions.jl
KernelAbstractions.jl copied to clipboard
Performance discrepancy against CUDA
Reproducer:
CUDA.jl:
using CUDA
using Adapt
CUDA.allowscalar(false)
# using KernelAbstractions
is_valid_index(meta, ui) =
1 ≤ ui[1] ≤ params(meta)[4] &&
1 ≤ ui[2] ≤ params(meta)[1] &&
1 ≤ ui[3] ≤ params(meta)[2] &&
1 ≤ ui[4] ≤ params(meta)[5]
@inline function universal_index_cuda(meta, src)
(v,) = CUDA.threadIdx()
(h, ij) = CUDA.blockIdx()
(Ni, Nj, Nv, _, Nh) = params(meta)
Ni * Nj < ij && return CartesianIndex((-1, -1, 1, -1, -1))
@inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
return CartesianIndex((v, i, j, 1, h))
end
function my_memcpy_cuda!(dest, src, meta)
(Ni, Nj, _, Nv, Nh) = params(meta)
kernel = CUDA.@cuda(
always_inline = true,
launch = false,
my_memcpy_kernel!(dest, src, meta)
)
threads = (Nv, )
blocks = (Nh, Ni * Nj)
kernel(dest, src, meta; threads, blocks)
return nothing
end
function my_memcpy_kernel!(dest, src, meta)
ui = universal_index_cuda(meta, src)
(Ni, Nj, _, Nv, Nh) = params(meta)
src_arr = CUDA.CuStaticSharedArray(eltype(src), (Nv,))
is_valid_index(meta, ui) && (src_arr[ui[1]] = src[ui])
CUDA.sync_threads()
if is_valid_index(meta, ui)
src_shmem = reconstruct_src(src, src_arr)
dest[ui] = src_shmem[ui[1]]
end
return nothing
end
reconstruct_src(src, src_arr) = src_arr
params(p::Tuple) = map(unval, p)
unval(::Val{n}) where {n} = n
unval(n) = n
Ni = 4
Nj = 4
Nv = 64
Nh = 5400
meta = (Val(Ni), Val(Nj), 1, Val(Nv), Val(Nh))
_src1 = CUDA.CuArray(rand(Float32, Nv, Ni, Nj, 1, Nh));
_src2 = similar(_src1); @. _src2 = _src1
_dest1 = similar(_src1);
_dest2 = similar(_src1);
my_memcpy_cuda!(_dest1, _src1, meta)
@. _dest2 = _src2
using Test
@test all(Array(_dest2) .== Array(_dest1))
CUDA.@profile begin
my_memcpy_cuda!(_dest1, _src1, meta)
my_memcpy_cuda!(_dest1, _src1, meta)
my_memcpy_cuda!(_dest1, _src1, meta)
end
KA.jl
using CUDA
using KernelAbstractions
using Adapt
CUDA.allowscalar(false)
using CUDA.CUDAKernels
# using KernelAbstractions
@inline function universal_index_KA(
meta,
ilc, # @index(Local, Cartesian)
igc, # @index(Group, Cartesian)
gs, # @groupsize()
)
@inbounds begin
v = ilc[1]
(_, h, ij) = igc.I
(Ni, Nj, _, Nv, Nh) = params(meta)
if Ni * Nj < ij
return CartesianIndex((-1, -1, 1, -1, -1))
else
@inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
return CartesianIndex((v, i, j, 1, h))
end
end
end
is_valid_index(meta, ui) =
1 ≤ ui[1] ≤ params(meta)[4] &&
1 ≤ ui[2] ≤ params(meta)[1] &&
1 ≤ ui[3] ≤ params(meta)[2] &&
1 ≤ ui[4] ≤ params(meta)[5]
function my_memcpy_KA!(dest, src, meta)
(Ni, Nj, _, Nv, Nh) = params(meta)
backend = CUDABackend()
group = (Nv, )
ndrange = (Nv, Nh, Ni * Nj)
kernel = my_memcpy_kernel!(backend, group)
kernel(dest, src, meta, ndrange = ndrange)
return nothing
end
@kernel function my_memcpy_kernel!(dest, src, meta)
ilc = @index(Local, Cartesian); igc = @index(Group, Cartesian); gs = @groupsize()
ui = universal_index_KA(meta, ilc, igc, gs)
(Ni, Nj, _, Nv, Nh) = @uniform params(meta)
src_arr = @localmem eltype(src) (Nv,)
is_valid_index(meta, ui) && (src_arr[ui[1]] = src[ui])
@synchronize
ilc = @index(Local, Cartesian); igc = @index(Group, Cartesian); gs = @groupsize()
ui = universal_index_KA(meta, ilc, igc, gs)
if is_valid_index(meta, ui)
src_shmem = reconstruct_src(src, src_arr)
dest[ui] = src_shmem[ui[1]]
end
end
reconstruct_src(src, src_arr) = src_arr
params(p::Tuple) = map(unval, p)
unval(::Val{n}) where {n} = n
unval(n) = n
Ni = 4
Nj = 4
Nv = 64
Nh = 5400
meta = (Val(Ni), Val(Nj), 1, Val(Nv), Val(Nh))
_src1 = CUDA.CuArray(rand(Float32, Nv, Ni, Nj, 1, Nh));
_src2 = similar(_src1); @. _src2 = _src1
_dest1 = similar(_src1);
_dest2 = similar(_src1);
my_memcpy_KA!(_dest1, _src1, meta)
@. _dest2 = _src2
using Test
@test all(Array(_dest2) .== Array(_dest1))
CUDA.@profile begin
my_memcpy_KA!(_dest1, _src1, meta)
my_memcpy_KA!(_dest1, _src1, meta)
my_memcpy_KA!(_dest1, _src1, meta)
end
Output:
julia> using Revise; include("../ka_reproducer.jl")
Profiler ran for 429.39 µs, capturing 68 events.
Host-side activity: calling CUDA APIs took 77.96 µs (18.16% of the trace)
┌──────────┬────────────┬───────┬────────────────────────────────────┬────────────────┐
│ Time (%) │ Total time │ Calls │ Time distribution │ Name │
├──────────┼────────────┼───────┼────────────────────────────────────┼────────────────┤
│ 17.38% │ 74.63 µs │ 3 │ 24.88 µs ± 33.83 ( 3.81 ‥ 63.9) │ cuLaunchKernel │
└──────────┴────────────┴───────┴────────────────────────────────────┴────────────────┘
Device-side activity: GPU was busy for 304.94 µs (71.02% of the trace)
┌──────────┬────────────┬───────┬──────────────────────────────────────┬─────────────────────────────────────────────────────────────────────
│ Time (%) │ Total time │ Calls │ Time distribution │ Name ⋯
├──────────┼────────────┼───────┼──────────────────────────────────────┼─────────────────────────────────────────────────────────────────────
│ 71.02% │ 304.94 µs │ 3 │ 101.65 µs ± 0.36 (101.33 ‥ 102.04) │ gpu_my_memcpy_kernel_(CompilerMetadata<DynamicSize, DynamicCheck, ⋯
└──────────┴────────────┴───────┴──────────────────────────────────────┴─────────────────────────────────────────────────────────────────────
1 column omitted
julia> using Revise; include("../cuda_reproducer.jl")
Profiler ran for 341.42 µs, capturing 68 events.
Host-side activity: calling CUDA APIs took 70.1 µs (20.53% of the trace)
┌──────────┬────────────┬───────┬─────────────────────────────────────┬────────────────┐
│ Time (%) │ Total time │ Calls │ Time distribution │ Name │
├──────────┼────────────┼───────┼─────────────────────────────────────┼────────────────┤
│ 19.55% │ 66.76 µs │ 3 │ 22.25 µs ± 28.88 ( 4.05 ‥ 55.55) │ cuLaunchKernel │
└──────────┴────────────┴───────┴─────────────────────────────────────┴────────────────┘
Device-side activity: GPU was busy for 250.34 µs (73.32% of the trace)
┌──────────┬────────────┬───────┬─────────────────────────────────────┬──────────────────────────────────────────────────────────────────────
│ Time (%) │ Total time │ Calls │ Time distribution │ Name ⋯
├──────────┼────────────┼───────┼─────────────────────────────────────┼──────────────────────────────────────────────────────────────────────
│ 73.32% │ 250.34 µs │ 3 │ 83.45 µs ± 0.24 ( 83.21 ‥ 83.68) │ my_memcpy_kernel_(CuDeviceArray<Float32, 5, 1>, CuDeviceArray<Float ⋯
└──────────┴────────────┴───────┴─────────────────────────────────────┴──────────────────────────────────────────────────────────────────────
My real-world example is much more complicated, and incurs a 25x slowdown. I'm not sure why. Here is the real-world case:
(Toggle implicit_tendency! at the top of the script)
Click to see script
ENV["CLIMACOMMS_DEVICE"] = "CUDA";
high_res = true;
# implicit_tendency!(Yₜ, Y, p, t) = implicit_tendency_KA!(Yₜ, Y, p, t)
implicit_tendency!(Yₜ, Y, p, t) = implicit_tendency_cuda!(Yₜ, Y, p, t)
using CUDA
import ClimaComms
ClimaComms.@import_required_backends
using ClimaCore.CommonSpaces
import ClimaAtmos as CA
import ClimaCore.Fields.StaticArrays: MArray
using LazyBroadcast: lazy
using LinearAlgebra: ×, dot, norm
import ClimaAtmos.Parameters as CAP
import Thermodynamics as TD
import SciMLBase
import ClimaCore.Grids
import ClimaCore
using KernelAbstractions
import KernelAbstractions as KA
import ClimaTimeSteppers as CTS
import ClimaCore.Geometry
import ClimaCore.MatrixFields: @name, ⋅
import ClimaCore.MatrixFields: DiagonalMatrixRow, BidiagonalMatrixRow
import LinearAlgebra: Adjoint
import LinearAlgebra: adjoint
import LinearAlgebra as LA
import ClimaCore: Operators, Topologies, DataLayouts
import ClimaCore.MatrixFields
import ClimaCore.Spaces
import ClimaCore.Fields
# import KernelAbstractions as KA
# using KernelAbstractions
# Unless the kernel here fills the shmem, we cannot use the shmem getidx.
# So, let's disable for now. Maybe we can write a simple shmem transformation
# of the broadcasted object + the state.
const cccuda_ext = Base.get_extension(ClimaCore, :ClimaCoreCUDAExt);
# LazyBroadcast calls instantiate on intermediate broadcast expressions, so we
# should be determining AbstractStencilStyle locally.
# ClimaCore bug: stencil style should be locally determined at the top level, not using the recursive Operators.any_fd_shmem_supported(bc)
Operators.AbstractStencilStyle(bc, ::ClimaComms.CUDADevice) =
cccuda_ext.CUDAColumnStencilStyle
Operators.fd_shmem_is_supported(bc::Base.Broadcast.Broadcasted) = false
ClimaCore.Operators.use_fd_shmem() = false
@info "Arch: $(ClimaComms.device())"
if get(ENV, "CLIMACOMMS_DEVICE", "CPU") == "CUDA"
using CUDA
using CUDA.CUDAKernels
CUDA.allowscalar(false)
else
end
# allow on-device use of lazy broadcast objects
DataLayouts.parent_array_type(::Type{<:CUDA.CuDeviceArray{T, N, A} where {N}}) where {T, A} =
CUDA.CuDeviceArray{T, N, A} where {N}
# allow on-device use of lazy broadcast objects
DataLayouts.promote_parent_array_type(
::Type{CUDA.CuDeviceArray{T1, N, B} where {N}},
::Type{CUDA.CuDeviceArray{T2, N, B} where {N}},
) where {T1, T2, B} = CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N}
# Ditch sizes (they're never actually used!)
DataLayouts.promote_parent_array_type(
::Type{MArray{S1, T1}},
::Type{MArray{S2, T2}},
) where {S1, T1, S2, T2} = MArray{S, promote_type(T1, T2)} where {S}
DataLayouts.promote_parent_array_type(
::Type{MArray{S1, T1} where {S1}},
::Type{MArray{S2, T2}},
) where {T1, S2, T2} = MArray{S, promote_type(T1, T2)} where {S}
DataLayouts.promote_parent_array_type(
::Type{MArray{S1, T1}},
::Type{MArray{S2, T2} where {S2}},
) where {S1, T1, T2} = MArray{S, promote_type(T1, T2)} where {S}
# allow on-device use of lazy broadcast objects with different type params
DataLayouts.promote_parent_array_type(
::Type{CUDA.CuDeviceArray{T1, N, B1} where {N}},
::Type{CUDA.CuDeviceArray{T2, N, B2} where {N}},
) where {T1, T2, B1, B2} = CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}
# allow on-device use of lazy broadcast objects with different type params
DataLayouts.promote_parent_array_type(
::Type{CUDA.CuDeviceArray{T1}},
::Type{CUDA.CuDeviceArray{T2, N, B2} where {N}},
) where {T1, T2, B2} = CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}
DataLayouts.promote_parent_array_type(
::Type{CUDA.CuDeviceArray{T1, N, B1} where {N}},
::Type{CUDA.CuDeviceArray{T2} where {N}},
) where {T1, T2, B1} = CUDA.CuDeviceArray{promote_type(T1, T2), N, B} where {N, B}
# Specialize to allow on-device call of `device` for `DeviceExtrudedFiniteDifferenceGrid`
ClimaComms.device(grid::Grids.DeviceExtrudedFiniteDifferenceGrid) =
ClimaComms.device(Grids.vertical_topology(grid))
# The existing implementation limits our ability to apply the same expressions from within kernels
ClimaComms.device(topology::Topologies.DeviceIntervalTopology) = ClimaComms.CUDADevice()
Fields.error_mismatched_spaces(::Type, ::Type) = nothing # causes unsupported dynamic function invocation
import ClimaAtmos: C1, C2, C12, C3, C123, CT1, CT2, CT12, CT3, CT123, UVW
import ClimaAtmos:
divₕ, wdivₕ, gradₕ, wgradₕ, curlₕ, wcurlₕ, ᶜinterp, ᶜdivᵥ, ᶜgradᵥ
import ClimaAtmos: ᶠinterp, ᶠgradᵥ, ᶠcurlᵥ, ᶜinterp_matrix, ᶠgradᵥ_matrix
import ClimaAtmos: ᶜadvdivᵥ, ᶜadvdivᵥ_matrix, ᶠwinterp, ᶠinterp_matrix
Fields.local_geometry_field(bc::Base.Broadcast.Broadcasted) =
Fields.local_geometry_field(axes(bc))
ᶜtendencies(ρ, uₕ, ρe_tot) = (; ρ, uₕ, ρe_tot)
ᶠtendencies(u₃) = (; u₃)
@inline is_valid_index(us, ui) = 1 ≤ ui[4] ≤ DataLayouts.get_Nv(us)
function implicit_tendency_bc!(Yₜ, Y, p, t)
Yₜ .= zero(eltype(Yₜ))
set_precomputed_quantities!(Y, p, t)
(; rayleigh_sponge, params, dt) = p
(; ᶜh_tot, ᶠu³, ᶜp) = p.precomputed
ᶜJ = Fields.local_geometry_field(Y.c).J
ᶜz = Fields.coordinate_field(Y.c).z
ᶠz = Fields.coordinate_field(Y.f).z
grav = FT(CAP.grav(params))
zmax = CA.z_max(axes(Y.f))
@. Yₜ.c.ρ -= ᶜdivᵥ(ᶠwinterp(ᶜJ, Y.c.ρ) * ᶠu³)
# Central advection of active tracers (e_tot and q_tot)
Yₜ.c.ρe_tot .+= CA.vertical_transport(Y.c.ρ, ᶠu³, ᶜh_tot, dt, Val(:none))
@. Yₜ.f.u₃ -= ᶠgradᵥ(ᶜp) / ᶠinterp(Y.c.ρ) + ᶠgradᵥ(Φ(grav, ᶜz))
@. Yₜ.f.u₃ -= CA.β_rayleigh_w(rayleigh_sponge, ᶠz, zmax) * Y.f.u₃
return nothing
end
function implicit_tendency_cuda!(Yₜ, Y, p, t)
ᶜspace = axes(Y.c)
ᶠspace = Spaces.face_space(ᶜspace)
ᶠNv = Spaces.nlevels(ᶠspace)
ᶜcf = Fields.coordinate_field(ᶜspace)
us = DataLayouts.UniversalSize(Fields.field_values(ᶜcf))
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
nitems = Ni * Nj * 1 * ᶠNv * Nh
ᶜYₜ = Yₜ.c
ᶠYₜ = Yₜ.f
ᶜY = Y.c
ᶠY = Y.f
(; rayleigh_sponge, params, dt) = p
p_kernel = (; rayleigh_sponge, params, dt)
zmax = Spaces.z_max(axes(ᶠY)) # DeviceIntervalTopology does not have mesh, and therefore cannot compute zmax
kernel = CUDA.@cuda(
always_inline = true,
launch = false,
implicit_tendency_kernel_cuda!(ᶜYₜ, ᶠYₜ, ᶜY, ᶠY, p_kernel, t, zmax)
)
threads = (ᶠNv, )
blocks = (Nh, Ni * Nj)
kernel(ᶜYₜ, ᶠYₜ, ᶜY, ᶠY, p_kernel, t, zmax; threads, blocks)
end
function implicit_tendency_kernel_cuda!(ᶜYₜ, ᶠYₜ, _ᶜY, _ᶠY, p, t, zmax)
ᶜY_fv = Fields.field_values(_ᶜY)
ᶠY_fv = Fields.field_values(_ᶠY)
FT = Spaces.undertype(axes(_ᶜY))
ᶜNv = Spaces.nlevels(axes(_ᶜY))
ᶠNv = Spaces.nlevels(axes(_ᶠY))
ᶜus = DataLayouts.UniversalSize(ᶜY_fv)
ᶠus = DataLayouts.UniversalSize(ᶠY_fv)
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(ᶠus)
ᶜTS = DataLayouts.typesize(FT, eltype(ᶜY_fv))
ᶠTS = DataLayouts.typesize(FT, eltype(ᶠY_fv))
ᶜlg = Spaces.local_geometry_data(axes(_ᶜY))
ᶠlg = Spaces.local_geometry_data(axes(_ᶠY))
ᶜTS_lg = DataLayouts.typesize(FT, eltype(ᶜlg))
ᶜui = universal_index_cuda(ᶜus)
ᶠui = universal_index_cuda(ᶠus)
# ilc = @index(Local, Cartesian)
# igc = @index(Group, Cartesian)
# gs = @groupsize()
# ᶜui = universal_index_KA(ᶠus, ilc, igc, gs)
# ᶠui = universal_index_KA(ᶠus, ilc, igc, gs)
ᶜY_arr = CUDA.CuStaticSharedArray(FT, (ᶜNv, ᶜTS)) # ᶜY_arr = @localmem FT (ᶜNv, ᶜTS)
ᶠY_arr = CUDA.CuStaticSharedArray(FT, (ᶠNv, ᶠTS)) # ᶠY_arr = @localmem FT (ᶠNv, ᶠTS)
ᶜdata_col = rebuild_column(ᶜY_fv, ᶜY_arr)
ᶠdata_col = rebuild_column(ᶠY_fv, ᶠY_arr)
ᶜlg_arr = CUDA.CuStaticSharedArray(FT, (ᶜNv, ᶜTS_lg)) # ᶜlg_arr = @localmem FT (ᶜNv, ᶜTS_lg)
ᶠlg_arr = CUDA.CuStaticSharedArray(FT, (ᶠNv, ᶜTS_lg)) # ᶠlg_arr = @localmem FT (ᶠNv, ᶜTS_lg)
(ᶜspace_col, ᶠspace_col) = column_spaces(_ᶜY, _ᶠY, ᶠui, ᶜlg_arr, ᶠlg_arr)
is_valid_index(ᶜus, ᶜui) && (ᶜdata_col[ᶜui] = ᶜY_fv[ᶜui])
is_valid_index(ᶠus, ᶠui) && (ᶠdata_col[ᶠui] = ᶠY_fv[ᶠui])
ᶜlg_col = Spaces.local_geometry_data(ᶜspace_col)
ᶠlg_col = Spaces.local_geometry_data(ᶠspace_col)
# is_valid_index(ᶜus, ᶜui) && (ᶜlg_col[ᶜui] = ᶜlg[ᶜui])
# is_valid_index(ᶠus, ᶠui) && (ᶠlg_col[ᶠui] = ᶠlg[ᶠui])
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.coordinates.z[ᶜui] = ᶜlg.coordinates.z[ᶜui]) # needed
is_valid_index(ᶠus, ᶠui) && (ᶠlg_col.coordinates.z[ᶠui] = ᶠlg.coordinates.z[ᶠui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.J[ᶜui] = ᶜlg.J[ᶜui]) # needed
is_valid_index(ᶠus, ᶠui) && (ᶠlg_col.J[ᶠui] = ᶠlg.J[ᶠui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.invJ[ᶜui] = ᶜlg.invJ[ᶜui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.gⁱʲ.components.data.:1[ᶜui] = ᶜlg.gⁱʲ.components.data.:1[ᶜui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.gⁱʲ.components.data.:2[ᶜui] = ᶜlg.gⁱʲ.components.data.:2[ᶜui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.gⁱʲ.components.data.:3[ᶜui] = ᶜlg.gⁱʲ.components.data.:3[ᶜui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.gⁱʲ.components.data.:4[ᶜui] = ᶜlg.gⁱʲ.components.data.:4[ᶜui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.gⁱʲ.components.data.:5[ᶜui] = ᶜlg.gⁱʲ.components.data.:5[ᶜui]) # needed
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col.gⁱʲ.components.data.:6[ᶜui] = ᶜlg.gⁱʲ.components.data.:6[ᶜui]) # needed
is_valid_index(ᶠus, ᶠui) && (ᶠlg_col.gⁱʲ.components.data.:9[ᶠui] = ᶠlg.gⁱʲ.components.data.:9[ᶠui]) # needed
CUDA.sync_threads()
# ilc = @index(Local, Cartesian)
# igc = @index(Group, Cartesian)
# gs = @groupsize()
# ᶜui = universal_index_KA(ᶠus, ilc, igc, gs)
# ᶠui = universal_index_KA(ᶠus, ilc, igc, gs)
ᶜdata_col = rebuild_column(ᶜY_fv, ᶜY_arr)
ᶠdata_col = rebuild_column(ᶠY_fv, ᶠY_arr)
# (ᶜspace_col, ᶠspace_col) = column_spaces(_ᶜY, _ᶠY, ᶠui, ᶜlg_arr, ᶠlg_arr)
if is_valid_index(ᶜus, ᶜui)
(ᶜY, ᶠY) = column_states(_ᶜY, _ᶠY, ᶜdata_col, ᶠdata_col, ᶠui, ᶜspace_col, ᶠspace_col)
ᶜbc = ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t, zmax)
(ᶜidx, ᶜhidx) = operator_inds(axes(ᶜY), ᶜui)
Fields.field_values(ᶜYₜ)[ᶜui] = Operators.getidx(axes(ᶜY), ᶜbc, ᶜidx, ᶜhidx)
# ᶜYₜ[ᶜui] = ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t)[ᶜui] # might be possible?
end
if is_valid_index(ᶠus, ᶠui)
(ᶜY, ᶠY) = column_states(_ᶜY, _ᶠY, ᶜdata_col, ᶠdata_col, ᶠui, ᶜspace_col, ᶠspace_col)
ᶠbc = ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t, zmax)
(ᶠidx, ᶠhidx) = operator_inds(axes(ᶠY), ᶠui)
Fields.field_values(ᶠYₜ)[ᶠui] = Operators.getidx(axes(ᶠY), ᶠbc, ᶠidx, ᶠhidx)
# ᶠYₜ[ᶠui] = ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t)[ᶠui] # might be possible?
end
return nothing
end
@inline function universal_index_cuda(us)
(v,) = CUDA.threadIdx()
(h, ij) = CUDA.blockIdx()
(Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
Ni * Nj < ij && return CartesianIndex((-1, -1, 1, -1, -1))
@inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
return CartesianIndex((i, j, 1, v, h))
end
@inline function universal_index_KA(
us,
ilc, # @index(Local, Cartesian)
igc, # @index(Group, Cartesian)
gs, # @groupsize()
)
@inbounds begin
v = ilc[1]
(_, h, ij) = igc.I
# @print("igc = $(igc.I), ilc = $(ilc.I)\n")
(Ni, Nj, _, _, _) = DataLayouts.universal_size(us)
if Ni * Nj < ij
ui = CartesianIndex((-1, -1, 1, -1, -1))
else
@inbounds (i, j) = CartesianIndices((Ni, Nj))[ij].I
ui = CartesianIndex((i, j, 1, v, h))
end
return ui
end
end
function implicit_tendency_KA!(@nospecialize(Yₜ), @nospecialize(Y), @nospecialize(p), @nospecialize(t))
ᶜspace = axes(Y.c)
ᶠspace = Spaces.face_space(ᶜspace)
ᶠNv = Spaces.nlevels(ᶠspace)
ᶜcf = Fields.coordinate_field(ᶜspace)
us = DataLayouts.UniversalSize(Fields.field_values(ᶜcf))
(Ni, Nj, _, _, Nh) = DataLayouts.universal_size(us)
nitems = Ni * Nj * 1 * ᶠNv * Nh
ᶜYₜ = Yₜ.c
ᶠYₜ = Yₜ.f
ᶜY = Y.c
ᶠY = Y.f
(; rayleigh_sponge, params, dt) = p
p_kernel = (; rayleigh_sponge, params, dt)
zmax = Spaces.z_max(axes(ᶠY)) # DeviceIntervalTopology does not have mesh, and therefore cannot compute zmax
backend = if ClimaComms.device(ᶜspace) isa ClimaComms.CUDADevice
CUDABackend()
else
KA.CPU()
end
group = (ᶠNv, )
kernel = implicit_tendency_kernel_KA!(backend, group)
# ndrange = (Nh, 1, Ni * Nj)
ndrange = (ᶠNv, Nh, Ni * Nj)
# @show ᶠNv, Ni, Nj, Nh, prod(ndrange) # (ᶠNv, Ni, Nj, Nh, prod(ndrange)) = (9, 2, 2, 24, 864)
kernel(ᶜYₜ, ᶠYₜ, ᶜY, ᶠY, p_kernel, t, zmax, ndrange = ndrange)
end
@kernel function implicit_tendency_kernel_KA!(ᶜYₜ, ᶠYₜ, @Const(_ᶜY), @Const(_ᶠY), @Const(p), @Const(t), @Const(zmax))
ᶜY_fv = @uniform Fields.field_values(_ᶜY)
ᶠY_fv = @uniform Fields.field_values(_ᶠY)
FT = @uniform Spaces.undertype(axes(_ᶜY))
ᶜNv = @uniform Spaces.nlevels(axes(_ᶜY))
ᶠNv = @uniform Spaces.nlevels(axes(_ᶠY))
ᶜus = @uniform DataLayouts.UniversalSize(ᶜY_fv)
ᶠus = @uniform DataLayouts.UniversalSize(ᶠY_fv)
(Ni, Nj, _, _, Nh) = @uniform DataLayouts.universal_size(ᶠus)
ᶜTS = @uniform DataLayouts.typesize(FT, eltype(ᶜY_fv))
ᶠTS = @uniform DataLayouts.typesize(FT, eltype(ᶠY_fv))
ᶜlg = @uniform Spaces.local_geometry_data(axes(_ᶜY))
ᶠlg = @uniform Spaces.local_geometry_data(axes(_ᶠY))
ᶜTS_lg = @uniform DataLayouts.typesize(FT, eltype(ᶜlg))
ilc = @index(Local, Cartesian)
igc = @index(Group, Cartesian)
gs = @groupsize()
ᶜui = universal_index_KA(ᶠus, ilc, igc, gs)
ᶠui = universal_index_KA(ᶠus, ilc, igc, gs)
# @print("ᶜui = $(ᶜui.I)\n")
ᶜY_arr = @localmem FT (ᶜNv, ᶜTS)
ᶠY_arr = @localmem FT (ᶠNv, ᶠTS)
ᶜdata_col = rebuild_column(ᶜY_fv, ᶜY_arr)
ᶠdata_col = rebuild_column(ᶠY_fv, ᶠY_arr)
ᶜlg_arr = @localmem FT (ᶜNv, ᶜTS_lg)
ᶠlg_arr = @localmem FT (ᶠNv, ᶜTS_lg)
(ᶜspace_col, ᶠspace_col) = column_spaces(_ᶜY, _ᶠY, ᶠui, ᶜlg_arr, ᶠlg_arr)
is_valid_index(ᶜus, ᶜui) && (ᶜdata_col[ᶜui] = ᶜY_fv[ᶜui])
is_valid_index(ᶠus, ᶠui) && (ᶠdata_col[ᶠui] = ᶠY_fv[ᶠui])
ᶜlg_col = Spaces.local_geometry_data(ᶜspace_col)
ᶠlg_col = Spaces.local_geometry_data(ᶠspace_col)
is_valid_index(ᶜus, ᶜui) && (ᶜlg_col[ᶜui] = ᶜlg[ᶜui])
is_valid_index(ᶠus, ᶠui) && (ᶠlg_col[ᶠui] = ᶠlg[ᶠui])
@synchronize
ilc = @index(Local, Cartesian)
igc = @index(Group, Cartesian)
gs = @groupsize()
ᶜui = universal_index_KA(ᶠus, ilc, igc, gs)
ᶠui = universal_index_KA(ᶠus, ilc, igc, gs)
ᶜdata_col = rebuild_column(ᶜY_fv, ᶜY_arr)
ᶠdata_col = rebuild_column(ᶠY_fv, ᶠY_arr)
(ᶜspace_col, ᶠspace_col) = column_spaces(_ᶜY, _ᶠY, ᶠui, ᶜlg_arr, ᶠlg_arr)
if is_valid_index(ᶜus, ᶜui)
(ᶜY, ᶠY) = column_states(_ᶜY, _ᶠY, ᶜdata_col, ᶠdata_col, ᶠui, ᶜspace_col, ᶠspace_col)
ᶜbc = ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t, zmax)
(ᶜidx, ᶜhidx) = operator_inds(axes(ᶜY), ᶜui)
Fields.field_values(ᶜYₜ)[ᶜui] = Operators.getidx(axes(ᶜY), ᶜbc, ᶜidx, ᶜhidx)
# ᶜYₜ[ᶜui] = ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t)[ᶜui] # might be possible?
end
if is_valid_index(ᶠus, ᶠui)
(ᶜY, ᶠY) = column_states(_ᶜY, _ᶠY, ᶜdata_col, ᶠdata_col, ᶠui, ᶜspace_col, ᶠspace_col)
ᶠbc = ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t, zmax)
(ᶠidx, ᶠhidx) = operator_inds(axes(ᶠY), ᶠui)
Fields.field_values(ᶠYₜ)[ᶠui] = Operators.getidx(axes(ᶠY), ᶠbc, ᶠidx, ᶠhidx)
# ᶠYₜ[ᶠui] = ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t)[ᶠui] # might be possible?
end
end
@inline function operator_inds(space, I)
li = Operators.left_idx(space)
(i, j, _, v, h) = I.I
hidx = (i, j, h)
idx = v - 1 + li
return (idx, hidx)
end
@inline cartesian_indices(field::Fields.Field) =
cartesian_indices(Fields.field_values(field))
@inline cartesian_indices(data::DataLayouts.AbstractData) =
cartesian_indices(DataLayouts.UniversalSize(data))
@inline cartesian_indices(us::DataLayouts.UniversalSize) =
CartesianIndices(map(Base.OneTo, DataLayouts.universal_size(us)))
@inline universal_index(x) = cartesian_indices(x)
function thermo_state(thermo_params, ᶜρ, ᶜρe_tot, ᶜK, grav, ᶜz)
return @. lazy(TD.PhaseDry_ρe(
thermo_params,
ᶜρ,
ᶜρe_tot / ᶜρ - ᶜK - Φ(grav, ᶜz),
))
end
# Drop everything except Nv and S:
@inline column_type_params(data::DataLayouts.AbstractData) = column_type_params(typeof(data))
@inline column_type_params(::Type{DataLayouts.IJFH{S, Nij, A}}) where {S, Nij, A} = (S, )
@inline column_type_params(::Type{DataLayouts.IJHF{S, Nij, A}}) where {S, Nij, A} = (S, )
@inline column_type_params(::Type{DataLayouts.IFH{S, Ni, A}}) where {S, Ni, A} = (S, )
@inline column_type_params(::Type{DataLayouts.IHF{S, Ni, A}}) where {S, Ni, A} = (S, )
@inline column_type_params(::Type{DataLayouts.DataF{S, A}}) where {S, A} = (S,)
@inline column_type_params(::Type{DataLayouts.IJF{S, Nij, A}}) where {S, Nij, A} = (S, )
@inline column_type_params(::Type{DataLayouts.IF{S, Ni, A}}) where {S, Ni, A} = (S, )
@inline column_type_params(::Type{DataLayouts.VF{S, Nv, A}}) where {S, Nv, A} = (S, Nv)
@inline column_type_params(::Type{DataLayouts.VIJFH{S, Nv, Nij, A}}) where {S, Nv, Nij, A} = (S, Nv)
@inline column_type_params(::Type{DataLayouts.VIJHF{S, Nv, Nij, A}}) where {S, Nv, Nij, A} = (S, Nv)
@inline column_type_params(::Type{DataLayouts.VIFH{S, Nv, Ni, A}}) where {S, Nv, Ni, A} = (S, Nv)
@inline column_type_params(::Type{DataLayouts.VIHF{S, Nv, Ni, A}}) where {S, Nv, Ni, A} = (S, Nv)
# Drop everything except V and F:
@inline column_singleton(::DataLayouts.IJFH) = DataLayouts.DataFSingleton()
@inline column_singleton(::DataLayouts.IJHF) = DataLayouts.DataFSingleton()
@inline column_singleton(::DataLayouts.IFH) = DataLayouts.DataFSingleton()
@inline column_singleton(::DataLayouts.IHF) = DataLayouts.DataFSingleton()
@inline column_singleton(::DataLayouts.DataF) = DataLayouts.DataFSingleton()
@inline column_singleton(::DataLayouts.IJF) = DataLayouts.DataFSingleton()
@inline column_singleton(::DataLayouts.IF) = DataLayouts.DataFSingleton()
@inline column_singleton(::DataLayouts.VF) = DataLayouts.VFSingleton()
@inline column_singleton(::DataLayouts.VIJFH) = DataLayouts.VFSingleton()
@inline column_singleton(::DataLayouts.VIJHF) = DataLayouts.VFSingleton()
@inline column_singleton(::DataLayouts.VIFH) = DataLayouts.VFSingleton()
@inline column_singleton(::DataLayouts.VIHF) = DataLayouts.VFSingleton()
function rebuild_column(data, array::AbstractArray)
s_column = column_singleton(data)
return DataLayouts.union_all(s_column){column_type_params(data)...}(array)
end
function column_lg_shmem(f, ui)
(i, j, _, _, h) = ui.I
colidx = Grids.ColumnIndex((i, j), h)
lg = Spaces.local_geometry_data(axes(f))
lg_col = Spaces.column(lg, colidx)
FT = Spaces.undertype(axes(f))
Nv = Spaces.nlevels(axes(f))
TS = DataLayouts.typesize(FT, eltype(lg_col))
lg_arr = CUDA.CuStaticSharedArray(FT, (Nv, TS))
return rebuild_column(lg_col, lg_arr)
end
function column_lg_shmem_KA(f, ui, lg_arr)
(i, j, _, _, h) = ui.I
colidx = Grids.ColumnIndex((i, j), h)
lg = Spaces.local_geometry_data(axes(f))
lg_col = Spaces.column(lg, colidx)
return rebuild_column(lg_col, lg_arr)
end
function column_spaces(ᶜY, ᶠY, ui, ᶜlg_arr, ᶠlg_arr)
(i, j, _, _, h) = ui.I
colidx = Grids.ColumnIndex((i, j), h)
ᶜlg_col = column_lg_shmem_KA(ᶜY, ui, ᶜlg_arr)
ᶠlg_col = column_lg_shmem_KA(ᶠY, ui, ᶠlg_arr)
col_space = Spaces.column(axes(ᶜY), colidx)
col_grid = Spaces.grid(col_space)
if col_grid isa Grids.ColumnGrid && col_grid.full_grid isa Grids.DeviceExtrudedFiniteDifferenceGrid
(; full_grid) = col_grid
(; vertical_topology, global_geometry) = full_grid
col_grid_shmem = Grids.DeviceFiniteDifferenceGrid(vertical_topology, global_geometry, ᶜlg_col, ᶠlg_col)
ᶜspace_col = Spaces.space(col_grid_shmem, Grids.CellCenter())
ᶠspace_col = Spaces.space(col_grid_shmem, Grids.CellFace())
elseif col_grid isa Grids.ColumnGrid && col_grid.full_grid isa Grids.ExtrudedFiniteDifferenceGrid
(; full_grid) = col_grid
(; vertical_grid, global_geometry) = full_grid
col_grid_shmem = Grids.FiniteDifferenceGrid(vertical_grid.topology, global_geometry, ᶜlg_col, ᶠlg_col)
ᶜspace_col = Spaces.space(col_grid_shmem, Grids.CellCenter())
ᶠspace_col = Spaces.space(col_grid_shmem, Grids.CellFace())
else
error("Uncaught case")
end
return (ᶜspace_col, ᶠspace_col)
end
function column_states(ᶜY, ᶠY, ᶜdata_col, ᶠdata_col, ui, ᶜspace_col, ᶠspace_col)
ᶜY_col = Fields.Field(ᶜdata_col, ᶜspace_col)
ᶠY_col = Fields.Field(ᶠdata_col, ᶠspace_col)
return (ᶜY_col, ᶠY_col)
end
function vindex()
(tv,) = CUDA.threadIdx()
(h, bv, ij) = CUDA.blockIdx()
v = tv + (bv - 1) * CUDA.blockDim().x
return v
end
function ᶜimplicit_tendency_bc(ᶜY, ᶠY, p, t, zmax)
(; rayleigh_sponge, params, dt) = p
ᶜz = Fields.coordinate_field(ᶜY).z
ᶜJ = Fields.local_geometry_field(ᶜY).J
ᶠz = Fields.coordinate_field(ᶠY).z
FT = Spaces.undertype(axes(ᶜY))
grav = FT(CAP.grav(params))
thermo_params = CAP.thermodynamics_params(params)
ᶜρ = ᶜY.ρ
ᶜρe_tot = ᶜY.ρe_tot
ᶜuₕ = ᶜY.uₕ
ᶠu₃ = ᶠY.u₃
ᶜK = CA.compute_kinetic(ᶜuₕ, ᶠu₃)
ᶜts = thermo_state(thermo_params, ᶜρ, ᶜρe_tot, ᶜK, grav, ᶜz)
ᶜp = @. lazy(TD.air_pressure(thermo_params, ᶜts))
ᶜh_tot = @. lazy(TD.total_specific_enthalpy(thermo_params, ᶜts, ᶜρe_tot / ᶜρ))
# Central advection of active tracers (e_tot and q_tot)
ᶠuₕ³ = @. lazy(ᶠwinterp(ᶜρ * ᶜJ, CT3(ᶜuₕ)))
ᶠu³ = @. lazy(ᶠuₕ³ + CT3(ᶠu₃))
tend_ρ_1 = @. lazy( - ᶜdivᵥ(ᶠwinterp(ᶜJ, ᶜρ) * ᶠuₕ³))
tend_ρe_tot_1 = CA.vertical_transport(ᶜρ, ᶠu³, ᶜh_tot, dt, Val(:none))
ᶜuₕ₀ = (zero(eltype(ᶜuₕ)),)
return @. lazy(ᶜtendencies(
tend_ρ_1,
- ᶜuₕ₀,
tend_ρe_tot_1,
))
end
function ᶠimplicit_tendency_bc(ᶜY, ᶠY, p, t, zmax)
(; rayleigh_sponge, params) = p
ᶜz = Fields.coordinate_field(ᶜY).z
ᶠz = Fields.coordinate_field(ᶠY).z
FT = Spaces.undertype(axes(ᶜY))
grav = FT(CAP.grav(params))
thermo_params = CAP.thermodynamics_params(params)
ᶜρ = ᶜY.ρ
ᶜρe_tot = ᶜY.ρe_tot
ᶜuₕ = ᶜY.uₕ
ᶠu₃ = ᶠY.u₃
ᶜK = CA.compute_kinetic(ᶜuₕ, ᶠu₃)
ᶜts = thermo_state(thermo_params, ᶜρ, ᶜρe_tot, ᶜK, grav, ᶜz)
ᶜp = @. lazy(TD.air_pressure(thermo_params, ᶜts))
bc1 = @. lazy(- (ᶠgradᵥ(ᶜp) / ᶠinterp(ᶜρ) + ᶠgradᵥ(Φ(grav, ᶜz))))
bc2 = @. lazy(- CA.β_rayleigh_w(rayleigh_sponge, ᶠz, zmax) * ᶠu₃)
return @. lazy(ᶠtendencies(bc1 + bc2))
end
function ImplicitEquationJacobian(
Y::Fields.FieldVector;
approximate_solve_iters = 1,
transform_flag = false,
)
FT = Spaces.undertype(axes(Y.c))
CTh = CA.CTh_vector_type(axes(Y.c))
BidiagonalRow_C3 = MatrixFields.BidiagonalMatrixRow{CA.C3{FT}}
BidiagonalRow_ACT3 =
MatrixFields.BidiagonalMatrixRow{LA.Adjoint{FT, CA.CT3{FT}}}
BidiagonalRow_C3xACTh = MatrixFields.BidiagonalMatrixRow{
typeof(zero(CA.C3{FT}) * zero(CTh{FT})'),
}
TridiagonalRow_C3xACT3 = MatrixFields.TridiagonalMatrixRow{
typeof(zero(CA.C3{FT}) * zero(CA.CT3{FT})'),
}
is_in_Y(name) = MatrixFields.has_field(Y, name)
sfc_if_available = is_in_Y(@name(sfc)) ? (@name(sfc),) : ()
# Note: We have to use FT(-1) * I instead of -I because inv(-1) == -1.0,
# which means that multiplying inv(-1) by a Float32 will yield a Float64.
identity_blocks = MatrixFields.unrolled_map(
name -> (name, name) => FT(-1) * LA.I,
(@name(c.ρ), sfc_if_available...),
)
active_scalar_names = (@name(c.ρ), @name(c.ρe_tot))
advection_blocks = (
MatrixFields.unrolled_map(
name -> (name, @name(f.u₃)) => similar(Y.c, BidiagonalRow_ACT3),
active_scalar_names,
)...,
MatrixFields.unrolled_map(
name -> (@name(f.u₃), name) => similar(Y.f, BidiagonalRow_C3),
active_scalar_names,
)...,
(@name(f.u₃), @name(c.uₕ)) => similar(Y.f, BidiagonalRow_C3xACTh),
(@name(f.u₃), @name(f.u₃)) => similar(Y.f, TridiagonalRow_C3xACT3),
)
diffused_scalar_names = (@name(c.ρe_tot),)
diffusion_blocks = MatrixFields.unrolled_map(
name -> (name, name) => FT(-1) * LA.I,
(diffused_scalar_names..., @name(c.uₕ)),
)
matrix = MatrixFields.FieldMatrix(
identity_blocks...,
advection_blocks...,
diffusion_blocks...,
)
names₁_group₁ = (@name(c.ρ), sfc_if_available...)
names₁_group₃ = (@name(c.ρe_tot),)
names₁ = (names₁_group₁..., names₁_group₃...)
alg₂ = MatrixFields.BlockLowerTriangularSolve(@name(c.uₕ))
alg = MatrixFields.BlockArrowheadSolve(names₁...; alg₂)
return CA.ImplicitEquationJacobian(
matrix,
MatrixFields.FieldMatrixSolver(alg, matrix, Y),
CA.IgnoreDerivative(), # diffusion_flag
CA.IgnoreDerivative(), # topography_flag
CA.IgnoreDerivative(), # sgs_advection_flag
CA.IgnoreDerivative(), # sgs_entr_detr_flag
CA.IgnoreDerivative(), # sgs_nh_pressure_flag
CA.IgnoreDerivative(), # sgs_mass_flux_flag
similar(Y),
similar(Y),
transform_flag,
Ref{FT}(),
)
end
function Wfact!(A, Y, p, dtγ, t)
FT = Spaces.undertype(axes(Y.c))
dtγ′ = FT(float(dtγ))
A.dtγ_ref[] = dtγ′
update_implicit_equation_jacobian!(A, Y, p, dtγ′)
end
Φ(grav, z) = grav * z
function update_implicit_equation_jacobian!(A, Y, p, dtγ)
(; matrix) = A
(; ᶜK, ᶜts, ᶜp, ᶜh_tot) = p.precomputed
(; ∂ᶜK_∂ᶜuₕ, ∂ᶜK_∂ᶠu₃, ᶠp_grad_matrix, ᶜadvection_matrix) = p
(; params) = p
FT = Spaces.undertype(axes(Y.c))
CTh = CA.CTh_vector_type(axes(Y.c))
one_C3xACT3 = C3(FT(1)) * CT3(FT(1))'
rs = p.rayleigh_sponge
ᶠz = Fields.coordinate_field(Y.f).z
zmax = CA.z_max(axes(Y.f))
T_0 = FT(CAP.T_0(params))
cp_d = FT(CAP.cp_d(params))
thermo_params = CAP.thermodynamics_params(params)
ᶜz = Fields.coordinate_field(Y.c).z
grav = FT(CAP.grav(params))
ᶜρ = Y.c.ρ
ᶜuₕ = Y.c.uₕ
ᶠu₃ = Y.f.u₃
ᶜJ = Fields.local_geometry_field(Y.c).J
ᶠgⁱʲ = Fields.local_geometry_field(Y.f).gⁱʲ
ᶜkappa_m = p.ᶜtemp_scalar
@. ᶜkappa_m =
TD.gas_constant_air(thermo_params, ᶜts) / TD.cv_m(thermo_params, ᶜts)
@. ∂ᶜK_∂ᶜuₕ = DiagonalMatrixRow(adjoint(CTh(ᶜuₕ)))
@. ∂ᶜK_∂ᶠu₃ =
ᶜinterp_matrix() ⋅ DiagonalMatrixRow(adjoint(CT3(ᶠu₃))) +
DiagonalMatrixRow(adjoint(CT3(ᶜuₕ))) ⋅ ᶜinterp_matrix()
@. ᶠp_grad_matrix = DiagonalMatrixRow(-1 / ᶠinterp(ᶜρ)) ⋅ ᶠgradᵥ_matrix()
@. ᶜadvection_matrix =
-(ᶜadvdivᵥ_matrix()) ⋅ DiagonalMatrixRow(ᶠwinterp(ᶜJ, ᶜρ))
∂ᶜρ_err_∂ᶠu₃ = matrix[@name(c.ρ), @name(f.u₃)]
@. ∂ᶜρ_err_∂ᶠu₃ = dtγ * ᶜadvection_matrix ⋅ DiagonalMatrixRow(CA.g³³(ᶠgⁱʲ))
∂ᶜρχ_err_∂ᶠu₃ = matrix[@name(c.ρe_tot), @name(f.u₃)]
@. ∂ᶜρχ_err_∂ᶠu₃ =
dtγ * ᶜadvection_matrix ⋅
DiagonalMatrixRow(ᶠinterp(ᶜh_tot) * CA.g³³(ᶠgⁱʲ))
∂ᶠu₃_err_∂ᶜρ = matrix[@name(f.u₃), @name(c.ρ)]
∂ᶠu₃_err_∂ᶜρe_tot = matrix[@name(f.u₃), @name(c.ρe_tot)]
@. ∂ᶠu₃_err_∂ᶜρ =
dtγ * (
ᶠp_grad_matrix ⋅
DiagonalMatrixRow(ᶜkappa_m * (T_0 * cp_d - ᶜK - Φ(grav, ᶜz))) +
DiagonalMatrixRow(ᶠgradᵥ(ᶜp) / abs2(ᶠinterp(ᶜρ))) ⋅
ᶠinterp_matrix()
)
@. ∂ᶠu₃_err_∂ᶜρe_tot = dtγ * ᶠp_grad_matrix ⋅ DiagonalMatrixRow(ᶜkappa_m)
∂ᶠu₃_err_∂ᶜuₕ = matrix[@name(f.u₃), @name(c.uₕ)]
∂ᶠu₃_err_∂ᶠu₃ = matrix[@name(f.u₃), @name(f.u₃)]
I_u₃ = DiagonalMatrixRow(one_C3xACT3)
@. ∂ᶠu₃_err_∂ᶜuₕ =
dtγ * ᶠp_grad_matrix ⋅ DiagonalMatrixRow(-(ᶜkappa_m) * ᶜρ) ⋅ ∂ᶜK_∂ᶜuₕ
@. ∂ᶠu₃_err_∂ᶠu₃ =
dtγ * (
ᶠp_grad_matrix ⋅ DiagonalMatrixRow(-(ᶜkappa_m) * ᶜρ) ⋅ ∂ᶜK_∂ᶠu₃ +
DiagonalMatrixRow(-CA.β_rayleigh_w(rs, ᶠz, zmax) * (one_C3xACT3,))
) - (I_u₃,)
end
function set_precomputed_quantities!(Y, p, t)
thermo_params = CAP.thermodynamics_params(p.params)
(; ᶜu, ᶠu³, ᶠu, ᶜK, ᶜts, ᶜp) = p.precomputed
ᶜρ = Y.c.ρ
ᶜuₕ = Y.c.uₕ
ᶜz = Fields.coordinate_field(Y.c).z
grav = FT(CAP.grav(params))
ᶠu₃ = Y.f.u₃
@. ᶜu = C123(ᶜuₕ) + ᶜinterp(C123(ᶠu₃))
ᶠu³ .= CA.compute_ᶠuₕ³(ᶜuₕ, ᶜρ) .+ CT3.(ᶠu₃)
ᶜK .= CA.compute_kinetic(ᶜuₕ, ᶠu₃)
@. ᶜts = TD.PhaseDry_ρe(
thermo_params,
Y.c.ρ,
Y.c.ρe_tot / Y.c.ρ - ᶜK - Φ(grav, ᶜz),
)
@. ᶜp = TD.air_pressure(thermo_params, ᶜts)
(; ᶜh_tot) = p.precomputed
@. ᶜh_tot =
TD.total_specific_enthalpy(thermo_params, ᶜts, Y.c.ρe_tot / Y.c.ρ)
return nothing
end
function dss!(Y, p, t)
Spaces.weighted_dss!(Y.c => p.ghost_buffer.c, Y.f => p.ghost_buffer.f)
return nothing
end
function remaining_tendency!(Yₜ, Yₜ_lim, Y, p, t)
# Yₜ_lim .= zero(eltype(Yₜ_lim))
Yₜ .= zero(eltype(Yₜ))
(; dt, params, rayleigh_sponge) = p
(; ᶜh_tot) = p.precomputed
(; ᶠu³, ᶜu, ᶜK, ᶜp) = p.precomputed
(; ᶜf³, ᶠf¹²) = p.precomputed
ᶜz = Fields.coordinate_field(Y.c).z
ᶜJ = Fields.local_geometry_field(Y.c).J
grav = FT(CAP.grav(params))
ᶜuₕ = Y.c.uₕ
ᶠu₃ = Y.f.u₃
ᶜρ = Y.c.ρ
@. Yₜ.c.ρ -= wdivₕ(ᶜρ * ᶜu)
@. Yₜ.c.ρe_tot -= wdivₕ(ᶜρ * ᶜh_tot * ᶜu)
@. Yₜ.c.uₕ -= C12(gradₕ(ᶜp) / ᶜρ + gradₕ(ᶜK + Φ(grav, ᶜz)))
ᶜω³ = p.scratch.ᶜtemp_CT3
ᶠω¹² = p.scratch.ᶠtemp_CT12
point_type = eltype(Fields.coordinate_field(Y.c))
if point_type <: Geometry.Abstract3DPoint
@. ᶜω³ = curlₕ(ᶜuₕ)
elseif point_type <: Geometry.Abstract2DPoint
@. ᶜω³ = zero(ᶜω³)
end
@. ᶠω¹² = ᶠcurlᵥ(ᶜuₕ)
@. ᶠω¹² += CT12(curlₕ(ᶠu₃))
# Without the CT12(), the right-hand side would be a CT1 or CT2 in 2D space.
ᶠω¹²′ = if isnothing(ᶠf¹²)
ᶠω¹² # shallow atmosphere
else
@. lazy(ᶠf¹² + ᶠω¹²) # deep atmosphere
end
@. Yₜ.c.uₕ -=
ᶜinterp(ᶠω¹²′ × (ᶠinterp(ᶜρ * ᶜJ) * ᶠu³)) / (ᶜρ * ᶜJ) +
(ᶜf³ + ᶜω³) × CT12(ᶜu)
@. Yₜ.f.u₃ -= ᶠω¹²′ × ᶠinterp(CT12(ᶜu)) + ᶠgradᵥ(ᶜK)
Yₜ.c.uₕ .+= CA.rayleigh_sponge_tendency_uₕ(ᶜuₕ, rayleigh_sponge)
return Yₜ
end
# This block:
# @time if !@isdefined(integrator)
FT = Float64;
if high_res
ᶜspace = ExtrudedCubedSphereSpace(
FT;
z_elem = 63,
z_min = 0,
z_max = 30000.0,
radius = 6.371e6,
h_elem = 30,
n_quad_points = 4,
staggering = CellCenter(),
);
else
ᶜspace = ExtrudedCubedSphereSpace(
FT;
z_elem = 8,
z_min = 0,
z_max = 30000.0,
radius = 6.371e6,
h_elem = 2,
n_quad_points = 2,
staggering = CellCenter(),
);
end
ᶠspace = Spaces.face_space(ᶜspace);
cnt = (; ρ = zero(FT), uₕ = zero(CA.C12{FT}), ρe_tot = zero(FT));
Yc = Fields.fill(cnt, ᶜspace);
fill!(parent(Yc.ρ), 1)
fill!(parent(Yc.uₕ), 0.01)
fill!(parent(Yc.ρe_tot), 1000.0)
Yf = Fields.fill((; u₃ = zero(CA.C3{FT})), ᶠspace);
Y = Fields.FieldVector(; c = Yc, f = Yf);
A = ImplicitEquationJacobian(
Y;
approximate_solve_iters = 2,
transform_flag = false, # assumes use_transform returns false
)
implicit_func = SciMLBase.ODEFunction(
implicit_tendency!;
jac_prototype = A,
Wfact = Wfact!, # assumes use_transform returns false
tgrad = (∂Y∂t, Y, p, t) -> (∂Y∂t .= 0),
)
func = CTS.ClimaODEFunction(;
T_exp_T_lim! = remaining_tendency!,
T_imp! = implicit_func,
# Can we just pass implicit_tendency! and jac_prototype etc.?
lim! = (Y, p, t, ref_Y) -> nothing, # limiters_func!
dss!,
cache! = set_precomputed_quantities!,
cache_imp! = set_precomputed_quantities!,
)
newtons_method = CTS.NewtonsMethod(; max_iters = 2)
params = CA.ClimaAtmosParameters(FT)
ᶠcoord = Fields.coordinate_field(ᶠspace);
ᶜcoord = Fields.coordinate_field(ᶜspace);
(; ᶜf³, ᶠf¹²) = CA.compute_coriolis(ᶜcoord, ᶠcoord, params);
scratch = (;
ᶜtemp_CT3 = Fields.Field(CT3{FT}, ᶜspace),
ᶠtemp_CT12 = Fields.Field(CT12{FT}, ᶠspace),
)
precomputed = (;
ᶜh_tot = Fields.Field(FT, ᶜspace),
ᶠu³ = Fields.Field(CA.CT3{FT}, ᶠspace),
ᶜf³,
ᶠf¹²,
ᶜp = Fields.Field(FT, ᶜspace),
ᶜK = Fields.Field(FT, ᶜspace),
ᶜts = Fields.Field(TD.PhaseDry{FT}, ᶜspace),
ᶠu = Fields.Field(C123{FT}, ᶠspace),
ᶜu = Fields.Field(C123{FT}, ᶜspace),
)
dt = FT(0.1)
ghost_buffer =
!CA.do_dss(axes(Y.c)) ? (;) :
(; c = Spaces.create_dss_buffer(Y.c), f = Spaces.create_dss_buffer(Y.f))
CTh = CA.CTh_vector_type(axes(Y.c))
p = (;
rayleigh_sponge = CA.RayleighSponge{FT}(;
zd = params.zd_rayleigh,
α_uₕ = params.alpha_rayleigh_uh,
α_w = params.alpha_rayleigh_w,
),
params,
∂ᶜK_∂ᶜuₕ = Fields.Field(DiagonalMatrixRow{Adjoint{FT, CTh{FT}}}, ᶜspace),
∂ᶜK_∂ᶠu₃ = Fields.Field(BidiagonalMatrixRow{Adjoint{FT, CT3{FT}}}, ᶜspace),
ᶜadvection_matrix = Fields.Field(
BidiagonalMatrixRow{Adjoint{FT, C3{FT}}},
ᶜspace,
),
ᶜtemp_scalar = Fields.Field(FT, ᶜspace),
ᶠp_grad_matrix = Fields.Field(BidiagonalMatrixRow{C3{FT}}, ᶠspace),
scratch,
ghost_buffer,
dt,
precomputed,
)
ode_algo = CTS.IMEXAlgorithm(CTS.ARS343(), newtons_method)
problem = SciMLBase.ODEProblem(func, Y, (FT(0), FT(1)), p)
integrator = SciMLBase.init(problem, ode_algo; dt)
Yₜ = similar(integrator.u);
# end
function main!(integrator, Yₜ, n)
for _ in 1:n
# @time SciMLBase.step!(integrator)
@time implicit_tendency!(Yₜ, integrator.u, integrator.p, integrator.t)
end
return nothing
end
using Test
if ClimaComms.device() isa ClimaComms.CUDADevice
Yₜ_bc = similar(Yₜ);
@. Yₜ_bc = 0
@. Yₜ = 0
Yc = integrator.u.c;
Yf = integrator.u.f;
fill!(parent(Yc.ρ), 1);
zc = Fields.coordinate_field(Yc).z;
zf = Fields.coordinate_field(Yf).z;
@. Yc.ρ += 0.1*sin(zc);
parent(Yf.u₃) .+= 0.001 .* sin.(parent(zf));
fill!(parent(Yc.uₕ), 0.01);
fill!(parent(Yc.ρe_tot), 100000.0);
implicit_tendency!(Yₜ, integrator.u, integrator.p, integrator.t)
implicit_tendency_bc!(Yₜ_bc, integrator.u, integrator.p, integrator.t)
abs_err_c = maximum(Array(abs.(parent(Yₜ.c) .- parent(Yₜ_bc.c))))
abs_err_f = maximum(Array(abs.(parent(Yₜ.f) .- parent(Yₜ_bc.f))))
results_match = abs_err_c < 6e-9 && abs_err_c < 6e-9
if !results_match
@show norm(Array(parent(Yₜ_bc.c))), norm(Array(parent(Yₜ.c)))
@show norm(Array(parent(Yₜ_bc.f))), norm(Array(parent(Yₜ.f)))
@show abs_err_c
@show abs_err_f
end
@test results_match
println(CUDA.@profile trace=true begin
# SciMLBase.step!(integrator)
implicit_tendency!(Yₜ, integrator.u, integrator.p, integrator.t)
implicit_tendency!(Yₜ, integrator.u, integrator.p, integrator.t)
implicit_tendency!(Yₜ, integrator.u, integrator.p, integrator.t)
implicit_tendency!(Yₜ, integrator.u, integrator.p, integrator.t)
end)
println(CUDA.@profile begin
# SciMLBase.step!(integrator)
@. Yₜ += 1
@. Yₜ += 1
@. Yₜ += 1
@. Yₜ += 1
end)
else
@info "Compiling main loop"
@time main!(integrator, Yₜ, 1)
@info "Running main loop"
@time main!(integrator, Yₜ, 3)
end
nothing