GPUCompiler.jl
GPUCompiler `code_typed` is type-unstable, whereas regular `code_typed` is fine
"""
    func_mixed_call(N)

Return a quoted block that, when evaluated, defines a `@generated` method

    runtime_mixed_call(::Val{RefTypes}, f::F, arg_1::T1, ..., arg_N::TN)

`RefTypes` is expected to be an indexable collection of `Bool`s
(e.g. `(false, true)`):

- `RefTypes[1]` selects whether `f` is dereferenced with `[]` before the call.
- `RefTypes[1+i]` selects the same for `arg_i`.

The generated body is marked inline. It calls the (possibly dereferenced)
`f` on the (possibly dereferenced) arguments.
"""
function func_mixed_call(N)
    # Typed positional arguments `arg_i::Ti` and their type parameters `Ti`.
    allargs = Expr[]
    typeargs = Symbol[]
    for i in 1:N
        arg = Symbol("arg_$i")
        targ = Symbol("T$i")
        push!(allargs, :($arg::$targ))
        push!(typeargs, targ)
    end
    # Only `$N`, `$(allargs...)` and `$(typeargs...)` are spliced here. The `$`
    # uses inside the nested `:(...)` / `quote` below are one quoting level deeper.
    # They are spliced by the generator when the method is specialized.
    return quote
        @generated function runtime_mixed_call(::Val{RefTypes}, f::F, $(allargs...)) where {RefTypes, F, $(typeargs...)}
            fexpr = :f
            if RefTypes[1]
                fexpr = :(($fexpr)[])
            end
            exprs2 = Union{Symbol,Expr}[]
            for i in 1:$N
                arg = Symbol("arg_$i")
                inarg = if RefTypes[1+i]
                    :($arg[])
                else
                    :($arg)
                end
                push!(exprs2, inarg)
            end
            return quote
                Base.@_inline_meta
                $fexpr($(exprs2...))
            end
        end
    end
end
# Install `runtime_mixed_call` methods taking 0 through 10 trailing arguments.
foreach(0:10) do nargs
    eval(func_mixed_call(nargs))
end
"""
    make(x, y, z)

Return a zero-argument closure `inner` that captures `x`, `y` and `z`. When
called, it assigns `y` to `x[i]` for every `i` in `z` and returns `nothing`.
"""
function make(x, y, z)
    function inner()
        for idx in z
            x[idx] = y
        end
    end
    return inner
end
# Closure under test: writes 3.0 into indices 1:3 of a length-10 vector of ones.
m = make(fill(1.0, 10), 3.0, 1:3)
"""
    threading_run(func)

Call the zero-argument `func` ten times in sequence and return `nothing`.
"""
function threading_run(func)
    for _ in 1:10
        func()
    end
    return nothing
end
using GPUCompiler
# Minimal compiler target for this reproducer. Of the target interface, only
# `llvm_triple` is defined, and it uses the host's `Sys.MACHINE`.
Base.@kwdef struct TestTarget <: AbstractCompilerTarget end

function GPUCompiler.llvm_triple(::TestTarget)
    return Sys.MACHINE
end
# Field-less compiler parameters, paired with `TestTarget` in `CompilerConfig`.
struct TestCompilerParams <: AbstractCompilerParams end

# TODO: We shouldn't blanket opt-out of invocation checking for every job;
# the opt-out is left disabled below for reference.
# GPUCompiler.check_invocation(job::CompilerJob{TestTarget}, entry::LLVM.Function) = nothing

# Jobs for `TestTarget` use the runtime slug "enzyme".
function GPUCompiler.runtime_slug(::CompilerJob{TestTarget})
    return "enzyme"
end
"""
    fspec(F, TT, world)

Return `GPUCompiler.methodinstance` for function type `F` and argument tuple
type `TT`, looked up in world age `world`. No return-type inference is done
here.
"""
@inline function fspec(@nospecialize(F), @nospecialize(TT), world::Integer)
    # Rebuild the tuple type from its parameters. For a concrete `Tuple{...}`
    # DataType this yields the same type as `TT`.
    primal_tt = Tuple{TT.parameters...}
    return GPUCompiler.methodinstance(F, primal_tt, world)
end
"""
    get_job(func, tt)

Build a `GPUCompiler.CompilerJob` for calling `func` with argument tuple type
`tt` in the current world age. The job uses a non-kernel
`CompilerConfig(TestTarget(), TestCompilerParams())`.
"""
function get_job(@nospecialize(func), @nospecialize(tt))
    current_world = Base.get_world_counter()
    mi = fspec(Core.Typeof(func), tt, current_world)
    config = CompilerConfig(TestTarget(), TestCompilerParams(); kernel=false)
    return GPUCompiler.CompilerJob(mi, config, current_world)
end
"""
    enzyme_code_typed(func, types; kwargs...)

Build a compiler job for `func` with argument tuple type `types` via `get_job`.
Return the result of `GPUCompiler.code_typed` on that job. Keyword arguments
are forwarded to `GPUCompiler.code_typed` only.
"""
function enzyme_code_typed(@nospecialize(func), @nospecialize(types); kwargs...)
    # `get_job` accepts no keyword arguments. Forwarding `kwargs` to it would
    # throw a MethodError as soon as any keyword was passed.
    job = get_job(func, types)
    return GPUCompiler.code_typed(job; kwargs...)
end
# `Ref{typeof(m)}` is an abstract type. `Ref(m)`, used in the `@code_typed` call
# below, is a concrete `Base.RefValue{typeof(m)}`. Pass the concrete type so
# both calls analyse the same signature. With the abstract type, `arg_1[]`
# can only be inferred as `Any`.
@show enzyme_code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), typeof(Ref(m))})
using InteractiveUtils
# Reference result from Julia's own `code_typed`, using a concrete `Ref(m)` argument.
@show(@code_typed(runtime_mixed_call(Val((false, true)), threading_run, Ref(m))))
On Julia 1.10, the output is:
wmoses@beast:~/git/Enzyme.jl (cai) $ ./julia-1.10.2/bin/julia --project
_
_ _ _(_)_ | Documentation: https://docs.julialang.org
(_) | (_) (_) |
_ _ _| |_ __ _ | Type "?" for help, "]?" for Pkg help.
| | | | | | |/ _` | |
| | |_| | | | (_| | | Version 1.10.2 (2024-03-01)
_/ |\__'_|_|_|\__'_| | Official https://julialang.org/ release
|__/ |
julia> include("sad.jl")
enzyme_code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}}) = Any[CodeInfo(
1 ─ %1 = (isa)(arg_1, Base.RefValue{var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}})::Bool
└── goto #3 if not %1
2 ─ %3 = π (arg_1, Base.RefValue{var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}})
│ %4 = Base.getfield(%3, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└── goto #4
3 ─ %6 = Base.getindex(arg_1)::Any
└── goto #4
4 ┄ %8 = φ (#2 => %4, #3 => %6)::Any
│ (f)(%8)::Nothing
└── return nothing
) => Nothing]
#= /home/wmoses/git/Enzyme.jl/sad.jl:102 =# @code_typed(runtime_mixed_call(Val((false, true)), threading_run, Ref(m))) = CodeInfo(
1 ── %1 = Base.getfield(arg_1, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└─── goto #17 if not true
2 ┄─ %3 = φ (#1 => 1, #16 => %41)::Int64
│ %4 = Core.getfield(%1, :z)::UnitRange{Int64}
│ %5 = Base.getfield(%4, :start)::Int64
│ %6 = Base.getfield(%4, :stop)::Int64
│ %7 = Base.slt_int(%6, %5)::Bool
└─── goto #4 if not %7
3 ── goto #5
4 ── %10 = Base.getfield(%4, :start)::Int64
│ %11 = Base.getfield(%4, :start)::Int64
└─── goto #5
5 ┄─ %13 = φ (#3 => true, #4 => false)::Bool
│ %14 = φ (#4 => %10)::Int64
│ %15 = φ (#4 => %11)::Int64
│ %16 = Base.not_int(%13)::Bool
└─── goto #11 if not %16
6 ┄─ %18 = φ (#5 => %14, #10 => %29)::Int64
│ %19 = φ (#5 => %15, #10 => %30)::Int64
│ %20 = Core.getfield(%1, :x)::Vector{Float64}
│ %21 = Core.getfield(%1, :y)::Float64
│ Base.arrayset(true, %20, %21, %18)::Vector{Float64}
│ %23 = Base.getfield(%4, :stop)::Int64
│ %24 = (%19 === %23)::Bool
└─── goto #8 if not %24
7 ── goto #9
8 ── %27 = Base.add_int(%19, 1)::Int64
└─── goto #9
9 ┄─ %29 = φ (#8 => %27)::Int64
│ %30 = φ (#8 => %27)::Int64
│ %31 = φ (#7 => true, #8 => false)::Bool
│ %32 = Base.not_int(%31)::Bool
└─── goto #11 if not %32
10 ─ goto #6
11 ┄ goto #12
12 ─ %36 = (%3 === 10)::Bool
└─── goto #14 if not %36
13 ─ goto #15
14 ─ %39 = Base.add_int(%3, 1)::Int64
└─── goto #15
15 ┄ %41 = φ (#14 => %39)::Int64
│ %42 = φ (#13 => true, #14 => false)::Bool
│ %43 = Base.not_int(%42)::Bool
└─── goto #17 if not %43
16 ─ goto #2
17 ┄ goto #18
18 ─ return nothing
) => Nothing
CodeInfo(
1 ── %1 = Base.getfield(arg_1, :x)::var"#inner#12"{Vector{Float64}, Float64, UnitRange{Int64}}
└─── goto #17 if not true
2 ┄─ %3 = φ (#1 => 1, #16 => %41)::Int64
│ %4 = Core.getfield(%1, :z)::UnitRange{Int64}
│ %5 = Base.getfield(%4, :start)::Int64
│ %6 = Base.getfield(%4, :stop)::Int64
│ %7 = Base.slt_int(%6, %5)::Bool
└─── goto #4 if not %7
3 ── goto #5
4 ── %10 = Base.getfield(%4, :start)::Int64
│ %11 = Base.getfield(%4, :start)::Int64
└─── goto #5
5 ┄─ %13 = φ (#3 => true, #4 => false)::Bool
│ %14 = φ (#4 => %10)::Int64
│ %15 = φ (#4 => %11)::Int64
│ %16 = Base.not_int(%13)::Bool
└─── goto #11 if not %16
6 ┄─ %18 = φ (#5 => %14, #10 => %29)::Int64
│ %19 = φ (#5 => %15, #10 => %30)::Int64
│ %20 = Core.getfield(%1, :x)::Vector{Float64}
│ %21 = Core.getfield(%1, :y)::Float64
│ Base.arrayset(true, %20, %21, %18)::Vector{Float64}
│ %23 = Base.getfield(%4, :stop)::Int64
│ %24 = (%19 === %23)::Bool
└─── goto #8 if not %24
7 ── goto #9
8 ── %27 = Base.add_int(%19, 1)::Int64
└─── goto #9
9 ┄─ %29 = φ (#8 => %27)::Int64
│ %30 = φ (#8 => %27)::Int64
│ %31 = φ (#7 => true, #8 => false)::Bool
│ %32 = Base.not_int(%31)::Bool
└─── goto #11 if not %32
10 ─ goto #6
11 ┄ goto #12
12 ─ %36 = (%3 === 10)::Bool
└─── goto #14 if not %36
13 ─ goto #15
14 ─ %39 = Base.add_int(%3, 1)::Int64
└─── goto #15
15 ┄ %41 = φ (#14 => %39)::Int64
│ %42 = φ (#13 => true, #14 => false)::Bool
│ %43 = Base.not_int(%42)::Bool
└─── goto #17 if not %43
16 ─ goto #2
17 ┄ goto #18
18 ─ return nothing
) => Nothing
cc @vchuravy
wmoses@beast:~/git/GPUCompiler.jl ((HEAD detached at origin/master)) $ git log
commit 8b513be9e2230fe0dd1905b805e25fa049b24d1d (HEAD, tag: v0.26.5, origin/master, origin/HEAD)
Author: Tim Besard <[email protected]>
Date: Fri May 24 10:25:09 2024 +0200
Bump version.
julia-repl> @show code_typed(runtime_mixed_call, Tuple{Val{(false, true)}, typeof(threading_run), Ref{typeof(m)}})
1-element Vector{Any}:
CodeInfo(
1 ─ %1 = (isa)(arg_1, Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}})::Bool
└── goto #3 if not %1
2 ─ %3 = π (arg_1, Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}})
│ %4 = Base.getfield(%3, :x)::var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}
└── goto #4
3 ─ %6 = Base.getindex(arg_1)::Any
└── goto #4
4 ┄ %8 = φ (#2 => %4, #3 => %6)::Any
│ (f)(%8)::Nothing
└── return nothing
) => Nothing
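`Ref{typeof(m)}` is not the same as `typeof(Ref(m))`. `Ref{T}` is an abstract type, whereas `Ref(m)` constructs a concrete `Base.RefValue{T}`. When the abstract `Ref{typeof(m)}` is passed as the argument type, inference must allow for any `Ref` subtype. This produces the `isa` split and the `Base.getindex(arg_1)::Any` branch in the output above. Passing `typeof(Ref(m))` (i.e. `Base.RefValue{typeof(m)}`) gives the same type-stable code as plain `@code_typed`.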
julia> typeof(Ref(m))
Base.RefValue{var"#inner#29"{Vector{Float64}, Float64, UnitRange{Int64}}}