Support execution of code from external abstract interpreters
Currently, external abstract interpreters are successfully used to power analysis and compilation for external targets like GPUs, but there is no mechanism within Julia to execute the output of these abstract interpreters, and the community has developed various workarounds.
- Enzyme.jl uses GPUCompiler.jl + LLVM.jl to spin up its own external JIT and uses `ccall` to call from native Julia code into Enzyme-controlled land. It furthermore needs to detect dynamic call sites and replace them with calls to functions controlled by Enzyme.
- CassetteOverlay.jl uses the "cassette" transform to replace all calls `f(args...)` with calls to `overdub(ctx, f, args...)`. While this has worked in Cassette.jl for a long time, it also leads to hard-to-read backtraces and relies heavily on generated functions.
This PR is a combination of two ideas:
- The currently active compiler is a dynamically scoped value (implemented as a dedicated task-local for performance)
- #52233 unifies the cache infrastructure for native and external compilers, thus allows for easy querying inside the runtime and compiler
Concretely, this PR introduces `abstract type CompilerInstance end`, which allows for the creation of temporary `AbstractInterpreter` instances;
it then uses a new built-in `call_within` to switch between compiler instances. Furthermore, the compiler instance is also the owner of
the corresponding `CodeInstance`s.
After that, it is mostly an exercise of threading the compiler instance through in the right places.
TODO
- [ ] Go through all uses of explicit `jl_nothing` as an owner token and use the correct one.
- [ ] Expose the compiler instance to reflection methods
- [ ] Add abstract interpreter support for `call_within`
- [ ] Concrete evaluation will need to use `call_within`
- [ ] Interpreter support? JuliaInterpreter.jl support? Cthulhu support?
- [ ] Finalizers?
Example: Cassette style tracer
This PR doesn't do anything about how to work with the IR and write compiler plugins,
it also doesn't provide any hooks for compiler instances to modify the LLVM pipeline,
but below is a prototype of a Cassette-style tracer, where we have a `prehook`
and a `posthook` function to execute over the call graph.
const CC = Core.Compiler
import .CC: SSAValue, GlobalRef
# Compiler instance for the tracing example; `TracerInterp` reports it as its cache owner.
struct Tracer <: CC.AbstractCompiler end

# Create the abstract interpreter that serves a `Tracer` at world age `world`.
function CC.abstract_interpreter(::Tracer, world::UInt)
    return TracerInterp(; world = world)
end
"""
    TracerInterp(; world, compiler, inf_params, opt_params, inf_cache, code_cache)

Abstract interpreter backing the `Tracer` compiler instance.

All arguments are keywords with defaults: `world` defaults to the current world counter,
`compiler` to a fresh `Tracer()`, the parameter objects to their default constructors,
`inf_cache` to an empty vector, and `code_cache` to an `InternalCodeCache` created for
`compiler`.
"""
struct TracerInterp <: CC.AbstractInterpreter
    world::UInt
    inf_params::CC.InferenceParams
    opt_params::CC.OptimizationParams
    inf_cache::Vector{CC.InferenceResult}
    code_cache::CC.InternalCodeCache
    compiler::Tracer

    # `compiler` must precede `code_cache` because the latter's default refers to it.
    TracerInterp(;
        world::UInt = Base.get_world_counter(),
        compiler::Tracer = Tracer(),
        inf_params::CC.InferenceParams = CC.InferenceParams(),
        opt_params::CC.OptimizationParams = CC.OptimizationParams(),
        inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[],
        code_cache::CC.InternalCodeCache = CC.InternalCodeCache(compiler),
    ) = new(world, inf_params, opt_params, inf_cache, code_cache, compiler)
end
# `AbstractInterpreter` interface methods for `TracerInterp`: each one exposes the
# corresponding field stored on the interpreter.

function CC.InferenceParams(interp::TracerInterp)
    return interp.inf_params
end

function CC.OptimizationParams(interp::TracerInterp)
    return interp.opt_params
end

function CC.get_world_counter(interp::TracerInterp)
    return interp.world
end

function CC.get_inference_cache(interp::TracerInterp)
    return interp.inf_cache
end

# The code cache is viewed through a `WorldRange` built from the interpreter's world age.
function CC.code_cache(interp::TracerInterp)
    world_range = CC.WorldRange(interp.world)
    return CC.WorldView(interp.code_cache, world_range)
end

# Cached code is owned by the compiler instance, not by the interpreter itself.
function CC.cache_owner(interp::TracerInterp)
    return interp.compiler
end
import Core.Compiler: retrieve_code_info, maybe_validate_code
# Replace usage sites of `retrieve_code_info`. `OptimizationState` is one such site, but in
# all interesting use-cases it is derived from an `InferenceState`. There is a third one in
# `typeinf_ext` in case the module forbids inference.
#
# Builds the inference state for `result` from lowered code that is first passed through
# `transform`. Returns `nothing` when no lowered code can be retrieved. The code is
# validated both before and after the transformation.
function CC.InferenceState(result::CC.InferenceResult, cache_mode::UInt8, interp::TracerInterp)
    mi = result.linfo
    lowered = retrieve_code_info(mi, CC.get_world_counter(interp))
    if lowered === nothing
        return nothing
    end
    maybe_validate_code(mi, lowered, "lowered")
    transformed = transform(interp, mi, lowered)
    maybe_validate_code(mi, transformed, "transformed")
    return CC.InferenceState(result, transformed, cache_mode, interp)
end
##
# Cassette style early transform
##
# Allows for Cassette Pass transforms
# Look up `name` in `mod` without forcing resolution of the binding.
# Returns the bound value, or `nothing` when the binding is unresolved or undefined.
function static_eval(mod, name)
    available = Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
    return available ? getfield(mod, name) : nothing
end
# Forward declarations of the tracing hooks (generic functions without methods yet).
# `transform!` embeds both in the rewritten IR: `prehook` is called before each
# instrumented call and `posthook` after it.
function prehook end
function posthook end
# Produce the instrumented copy of `src` for the method instance `mi`.
#
# `src` is returned untouched when the method's name resolves to `Core._apply`, and an
# error is raised when it resolves to a builtin. Otherwise a copy of `src` is rewritten by
# `transform!` and its SSA bookkeeping is reset to match the new statement count.
# `interp` is currently unused.
function transform(interp, mi, src)
    def = mi.def
    callee = static_eval(def.module, def.name)
    # Debug output: print the method instance being transformed.
    ccall(:jl_, Cvoid, (Any,), mi)
    callee === Core._apply && return src
    callee isa Core.Builtin && error("Transforming builtin")
    # The guard that would skip the hooks themselves is currently disabled:
    # if callee === prehook || callee === posthook
    #     return src
    # end
    instrumented = copy(src)
    transform!(mi, instrumented)
    nstmts = length(instrumented.code)
    instrumented.ssavaluetypes = nstmts
    # XXX we need to copy flags
    instrumented.ssaflags = fill(0x00, nstmts)
    return instrumented
end
# Follow a chain of `SSAValue`s through `code` and return the first element that is not
# itself an `SSAValue`. A non-`SSAValue` input is returned as-is.
function ir_element(x, code::Vector)
    resolved = x
    while resolved isa Core.SSAValue
        resolved = code[resolved.id]
    end
    return resolved
end
"""
    is_ir_element(x, y, code::Vector)

Return `true` if `x === y` or if `x` is an `SSAValue` such that
`is_ir_element(code[x.id], y, code)` is `true`.

See also: [`replace_match!`](@ref), [`insert_statements!`](@ref)
"""
function is_ir_element(x, y, code::Vector)
    current = x
    # Walk the SSA chain until it reaches `y` or stops at a non-SSA element.
    while current !== y
        current isa Core.SSAValue || return false
        current = code[current.id]
    end
    return true
end
"""
    insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)

For every statement `stmt` at position `i` in `code` for which `stmtcount(stmt, i)` returns
an `Int`, remove `stmt`, and in its place, insert the statements returned by
`newstmts(stmt, i)`. If `stmtcount(stmt, i)` returns `nothing`, leave `stmt` alone.

For every insertion, all downstream `SSAValue`s, label indices, etc. are incremented
appropriately according to the number of inserted statements.

Proper usage of this function dictates that the following properties hold true:

- `code` is expected to be a valid value for the `code` field of a `CodeInfo` object.
- `codelocs` is expected to be a valid value for the `codelocs` field of a `CodeInfo` object.
- `newstmts(stmt, i)` should return a `Vector` of valid IR statements.
- `stmtcount` and `newstmts` must obey `stmtcount(stmt, i) == length(newstmts(stmt, i))` if
  `isa(stmtcount(stmt, i), Int)`.

To gain a mental model for this function's behavior, consider the following scenario. Let's
say our `code` object contains several statements:

    code = Any[oldstmt1, oldstmt2, oldstmt3, oldstmt4, oldstmt5, oldstmt6]
    codelocs = Int[1, 2, 3, 4, 5, 6]

Let's also say that `stmtcount` returns `2` for `stmtcount(oldstmt2, 2)`, returns `3`
for `stmtcount(oldstmt5, 5)`, and returns `nothing` for all other inputs. From this setup, we
can think of `code`/`codelocs` being modified in the following manner:

    newstmts2 = newstmts(oldstmt2, 2)
    newstmts5 = newstmts(oldstmt5, 5)
    code = Any[oldstmt1,
               newstmts2[1], newstmts2[2],
               oldstmt3, oldstmt4,
               newstmts5[1], newstmts5[2], newstmts5[3],
               oldstmt6]
    codelocs = Int[1, 2, 2, 3, 4, 5, 5, 5, 6]

See also: [`replace_match!`](@ref), [`is_ir_element`](@ref)
"""
function insert_statements!(code, codelocs, stmtcount, newstmts)
    ssachangemap = zeros(Int, length(code))
    labelchangemap = zeros(Int, length(code))
    pending = Tuple{Int,Int}[]
    # Pass 1: record, per replaced statement, how many extra statements it will grow by.
    for pos in eachindex(code)
        count = stmtcount(code[pos], pos)
        count === nothing && continue
        extra = count - 1
        push!(pending, (pos, extra))
        ssachangemap[pos] = extra
        # Labels pointing past this statement shift starting at the next position.
        if pos < lastindex(code)
            labelchangemap[pos + 1] = extra
        end
    end
    # Renumber SSA values and labels in place. The index correction below reads the
    # change map after this call and treats its entries as accumulated offsets.
    Core.Compiler.renumber_ir_elements!(code, ssachangemap, labelchangemap)
    # Pass 2: splice in the replacements. `newstmts` sees the renumbered statement and
    # its position in the updated numbering.
    for (oldpos, extra) in pending
        pos = oldpos + ssachangemap[oldpos] - extra
        replacement = newstmts(code[pos], pos)
        @assert length(replacement) == (extra + 1)
        code[pos] = replacement[end]
        # Insert the remaining statements back-to-front so they end up in the given order;
        # each one reuses the code location of the statement it replaces.
        for k in 1:(length(replacement) - 1)
            insert!(code, pos, replacement[end - k])
            insert!(codelocs, pos, codelocs[pos])
        end
    end
    return nothing
end
# Rewrite `src` in place so that every `:call` statement (bare or on the right-hand side
# of an assignment) becomes four statements:
#   1. `Core._call_within(nothing, prehook, ...)`
#   2. the original call
#   3. `Core._call_within(nothing, posthook, <result of 2>, ...)`
#   4. the original assignment target bound to, or a bare reference to, the result of 2
# `mi` is currently unused. Returns `nothing`.
function transform!(mi, src)
    # Every call statement is replaced by exactly four statements; others are left alone.
    function count_replacements(stmt, i)
        rhs = Base.Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt
        return Base.Meta.isexpr(rhs, :call) ? 4 : nothing
    end

    # `i` is the position of the first replacement statement, so the original call
    # lands at `i + 1`.
    function build_replacements(stmt, i)
        is_assignment = Base.Meta.isexpr(stmt, :(=))
        call = is_assignment ? stmt.args[2] : stmt
        callee = call.args[1]
        is_apply = is_ir_element(callee, GlobalRef(Core, :_apply), src.code) ||
                   is_ir_element(callee, GlobalRef(Core, :_apply_iterate), src.code)
        if is_apply
            # For apply-style calls the hooks receive the arguments after the
            # `_apply`/`_apply_iterate` callee itself.
            callf = call.args[2]
            hookargs = Any[callf, call.args[3:end]...]
        else
            hookargs = call.args
        end
        within = GlobalRef(Core, :_call_within)
        result = SSAValue(i + 1)
        return Any[
            Expr(:call, within, nothing, prehook, hookargs...),
            call,
            Expr(:call, within, nothing, posthook, result, hookargs...),
            is_assignment ? Expr(:(=), stmt.args[1], result) : result,
        ]
    end

    insert_statements!(src.code, src.codelocs, count_replacements, build_replacements)
    return nothing
end
# TODO:
# - anything called from prehook is still instrumented
# - either we can stop instrumentation / or switch to native (high overhead)
# - Similar problem with invokelatest, we somehow got recursion
# - invokelatest trace sees prehook call -- do we double transform?
"""
    Call(parent, f, args, children)

Node of the call tree recorded by the tracer.

# Fields
- `parent`: the enclosing `Call`, or `nothing` for the root of the tree.
- `f`: the callee of the recorded call.
- `args`: the arguments of the recorded call.
- `children`: the calls recorded beneath this one, in the order they were pushed.
"""
struct Call
    # Concrete annotations keep `parent.children` inferable instead of `Any`.
    parent::Union{Nothing,Call}
    # Callee and arguments are arbitrary values, so they stay untyped.
    f::Any
    args::Any
    children::Vector{Call}
end
# TODO: Handle task-safety
# Scoped value holding a mutable reference to the `Call` node currently being recorded.
# It is declared without a default value; `trace` binds it with `@with`, and the hooks
# move the reference down to a child (`prehook`) and back to its parent (`posthook`).
const call_tree = ScopedValue{Ref{Call}}()
"""
    prehook(f, args...)

Record a call to `f` with `args` as a new child of the node that `call_tree` currently
points at, then move `call_tree` to the new node. Returns the new node.
"""
function prehook(f, args...)
    cursor = call_tree[]
    caller = cursor[]
    node = Call(caller, f, args, Call[])
    push!(caller.children, node)
    return cursor[] = node
end
"""
    posthook(_, f, args...)

Move `call_tree` from the current node back to its parent and return that parent.
All arguments are ignored.
"""
function posthook(_, f, args...)
    cursor = call_tree[]
    finished = cursor[]
    return cursor[] = finished.parent
end
# Empty example function used as the tracing target below; it returns `nothing`.
f() = nothing
"""
    trace(f, args...)

Invoke `f(args...)` within a `Tracer` compiler instance and return the root `Call` node
of the tree recorded during that invocation.
"""
function trace(f, args...)
    root = Call(nothing, f, args, Call[])
    cursor = Ref(root)
    # Bind the cursor for the dynamic extent of the traced call only.
    @with call_tree => cursor begin
        Base.invoke_within(Tracer(), f, args...)
    end
    return root
end
# Example: trace the empty function `f` with no arguments; evaluates to the root `Call` node.
trace(f)
One thing I particularly enjoyed about Cassette was that composition is well defined. This proposal ignores composition: compiler instances don't form a stack, and there is no expectation that the output of one is meant to be consumed by another.
For Cassette, uninferred IR was a viable communication layer, but with compiler instances customization can occur at many levels, and we run into the pipeline-ordering problem. The hope would be that, using compiler instances, we could build an actual "compiler plugins" infrastructure that allows for the registration of passes/intrinsics and provides sensible composition, but that seems further away, and I do think we need some experimentation with compiler customization first before we tackle that.