Support execution of code from external abstract interpreters

Open vchuravy opened this issue 1 year ago • 1 comments

Currently, external abstract interpreters are successfully used to power analysis and compilation for external targets such as GPUs, but there is no mechanism within Julia to execute the output of these abstract interpreters, and the community has developed various workarounds.

  • Enzyme.jl uses GPUCompiler.jl + LLVM.jl to spin up its own external JIT and uses ccall to call from native Julia into Enzyme-controlled land. It furthermore needs to detect dynamic call sites and replace them with calls to functions controlled by Enzyme.
  • CassetteOverlay.jl uses the "cassette" transform to replace all calls f(args...) with calls to overdub(ctx, f, args...). While this has worked in Cassette.jl for a long time, it leads to hard-to-read backtraces and relies heavily on generated functions.
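
To make the "cassette" transform concrete, here is a minimal sketch that applies the same rewrite to a surface-syntax expression. It is only an illustration: Cassette itself works on lowered IR through generated functions, and `ToyCtx`, `toy_overdub`, and `rewrite_calls` are hypothetical names made up for this example.

```julia
# Toy context that records every function it intercepts (hypothetical type).
struct ToyCtx
    log::Vector{Any}
end

# Stand-in for `overdub`: record the callee, then run it unchanged.
function toy_overdub(ctx::ToyCtx, f, args...)
    push!(ctx.log, f)
    return f(args...)
end

# Rewrite every `f(args...)` in an expression into `toy_overdub(ctx, f, args...)`.
rewrite_calls(ctx, x) = x  # literals and symbols pass through untouched
function rewrite_calls(ctx, ex::Expr)
    args = Any[rewrite_calls(ctx, a) for a in ex.args]
    if ex.head === :call
        return Expr(:call, toy_overdub, ctx, args...)
    end
    return Expr(ex.head, args...)
end

ctx = ToyCtx(Any[])
rewritten = rewrite_calls(ctx, :(sin(1.0) + cos(2.0)))
result = eval(rewritten)
```

Every call now goes through `toy_overdub`, which is also why backtraces from such a transform show an extra frame per call.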

This PR is a combination of two ideas:

  • The currently active compiler is a dynamically scoped value (implemented as a dedicated task-local for performance)
  • #52233 unifies the cache infrastructure for native and external compilers, and thus allows for easy querying inside the runtime and compiler
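
The first idea can be approximated on a stock Julia (1.11 or newer) with `Base.ScopedValues`. This is a sketch under assumptions: the PR uses a dedicated task-local rather than a `ScopedValue`, and `ToyCompiler`, `ToyNative`, `ToyTracer`, `ACTIVE_COMPILER`, and `toy_call_within` are hypothetical names that only mimic the intended behaviour.

```julia
using Base.ScopedValues  # requires Julia 1.11 or newer

# Hypothetical stand-ins for compiler instances.
abstract type ToyCompiler end
struct ToyNative <: ToyCompiler end
struct ToyTracer <: ToyCompiler end

# The "currently active compiler", defaulting to the native one.
const ACTIVE_COMPILER = ScopedValue{ToyCompiler}(ToyNative())

active_compiler() = ACTIVE_COMPILER[]

# Run `f(args...)` with `compiler` active for the dynamic extent of the call.
toy_call_within(compiler::ToyCompiler, f, args...) =
    with(() -> f(args...), ACTIVE_COMPILER => compiler)

# A callee observes whichever compiler its caller selected.
describe() = string(nameof(typeof(active_compiler())))

outside = describe()
inside = toy_call_within(ToyTracer(), describe)
after = describe()
```

The previous value is restored automatically when the call returns, which is the property that makes dynamic scope a natural fit for switching compilers.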

Concretely, this PR introduces abstract type CompilerInstance end, which allows for the creation of temporary AbstractInterpreter instances. It then uses a new built-in, call_within, to switch between compiler instances. Furthermore, the compiler instance is also the owner of the corresponding CodeInstances.
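
The ownership part can be pictured as a code cache keyed by owner as well as by signature. The sketch below is a toy model, not the runtime's cache layout: `ToyCodeInstance`, `ToyTracerOwner`, `TOY_CACHE`, `toy_store!`, and `toy_lookup` are hypothetical, and `nothing` stands in for the native compiler's owner token.

```julia
# Hypothetical toy types; not the runtime's CodeInstance or cache layout.
struct ToyCodeInstance
    owner::Any      # `nothing` stands for the native compiler
    sig::Type       # call signature the code was compiled for
    code::String    # placeholder for the compiled artifact
end

struct ToyTracerOwner end

# One shared cache; entries are partitioned by their owner.
const TOY_CACHE = Dict{Tuple{Any,Type},ToyCodeInstance}()

function toy_store!(owner, sig::Type, code::String)
    ci = ToyCodeInstance(owner, sig, code)
    TOY_CACHE[(owner, sig)] = ci
    return ci
end

toy_lookup(owner, sig::Type) = get(TOY_CACHE, (owner, sig), nothing)

sig = Tuple{typeof(sin),Float64}
toy_store!(nothing, sig, "native code")
toy_store!(ToyTracerOwner(), sig, "instrumented code")
```

The same signature can then hold different code per compiler instance, and a lookup with the wrong owner simply misses.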

After that, it is mostly an exercise in threading the compiler instance through the right places.

TODO

  • [ ] Go through all uses of explicit jl_nothing as an owner token and use the correct one.
  • [ ] Expose the compiler instance to reflection methods
  • [ ] Add abstract interpreter support for call_within
  • [ ] Concrete evaluation will need to use call_within
  • [ ] Interpreter support? JuliaInterpreter.jl support? Cthulhu support?
  • [ ] Finalizers?

Example: Cassette style tracer

This PR doesn't do anything about how to work with the IR and write compiler plugins, and it doesn't provide any hooks for compiler instances to modify the LLVM pipeline, but below is a prototype of a Cassette-style tracer, in which a prehook and a posthook function are executed over the call graph.

const CC = Core.Compiler

import .CC: SSAValue, GlobalRef

struct Tracer <: CC.AbstractCompiler end
CC.abstract_interpreter(::Tracer, world::UInt) =
    TracerInterp(; world)

struct TracerInterp <: CC.AbstractInterpreter
    world::UInt
    inf_params::CC.InferenceParams
    opt_params::CC.OptimizationParams
    inf_cache::Vector{CC.InferenceResult}
    code_cache::CC.InternalCodeCache
    compiler::Tracer
    function TracerInterp(;
                world::UInt = Base.get_world_counter(),
                compiler::Tracer = Tracer(),
                inf_params::CC.InferenceParams = CC.InferenceParams(),
                opt_params::CC.OptimizationParams = CC.OptimizationParams(),
                inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[],
                code_cache::CC.InternalCodeCache = CC.InternalCodeCache(compiler))
        return new(world, inf_params, opt_params, inf_cache, code_cache, compiler)
    end
end

CC.InferenceParams(interp::TracerInterp) = interp.inf_params
CC.OptimizationParams(interp::TracerInterp) = interp.opt_params
CC.get_world_counter(interp::TracerInterp) = interp.world
CC.get_inference_cache(interp::TracerInterp) = interp.inf_cache
CC.code_cache(interp::TracerInterp) = CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
CC.cache_owner(interp::TracerInterp) = interp.compiler

import Core.Compiler: retrieve_code_info, maybe_validate_code
# Replace usage sites of `retrieve_code_info`. `OptimizationState` is one such site, but in all interesting use cases
# it is derived from an `InferenceState`. There is a third one in `typeinf_ext` in case the module forbids inference.
function CC.InferenceState(result::CC.InferenceResult, cache_mode::UInt8, interp::TracerInterp)
    world = CC.get_world_counter(interp)
    src = retrieve_code_info(result.linfo, world)
    src === nothing && return nothing
    maybe_validate_code(result.linfo, src, "lowered")
    src = transform(interp, result.linfo, src)
    maybe_validate_code(result.linfo, src, "transformed")
    return CC.InferenceState(result, src, cache_mode, interp)
end

##
# Cassette style early transform
##

# Allows for Cassette Pass transforms

function static_eval(mod, name)
    if Base.isbindingresolved(mod, name) && Base.isdefined(mod, name)
        return getfield(mod, name)
    else
        return nothing
    end
end

function prehook end
function posthook end

function transform(interp, mi, src)
    method = mi.def
    f = static_eval(method.module, method.name)
    ccall(:jl_, Cvoid, (Any,), mi) # debug: print the MethodInstance being transformed
    if f === Core._apply
        return src
    end
    if f isa Core.Builtin
        error("Transforming builtin")
    end
    # if f === prehook || f === posthook
    #     return src
    # end
    ci = copy(src)
    transform!(mi, ci)
    ci.ssavaluetypes = length(ci.code)
    # XXX we need to copy flags
    ci.ssaflags = [0x00 for _ in 1:length(ci.code)]
    # ccall(:jl_, Cvoid, (Any,), ci)
    return ci
end

function ir_element(x, code::Vector)
    while isa(x, Core.SSAValue)
        x = code[x.id]
    end
    return x
end

"""
    is_ir_element(x, y, code::Vector)

Return `true` if `x === y` or if `x` is an `SSAValue` such that
`is_ir_element(code[x.id], y, code)` is `true`.
See also: [`replace_match!`](@ref), [`insert_statements!`](@ref)
"""
function is_ir_element(x, y, code::Vector)
    result = false
    while true # break by default
        if x === y #
            result = true
            break
        elseif isa(x, Core.SSAValue)
            x = code[x.id]
        else
            break
        end
    end
    return result
end

"""
    insert_statements!(code::Vector, codelocs::Vector, stmtcount, newstmts)


For every statement `stmt` at position `i` in `code` for which `stmtcount(stmt, i)` returns
an `Int`, remove `stmt`, and in its place, insert the statements returned by
`newstmts(stmt, i)`. If `stmtcount(stmt, i)` returns `nothing`, leave `stmt` alone.

For every insertion, all downstream `SSAValue`s, label indices, etc. are incremented
appropriately according to number of inserted statements.

Proper usage of this function dictates that following properties hold true:

- `code` is expected to be a valid value for the `code` field of a `CodeInfo` object.
- `codelocs` is expected to be a valid value for the `codelocs` field of a `CodeInfo` object.
- `newstmts(stmt, i)` should return a `Vector` of valid IR statements.
- `stmtcount` and `newstmts` must obey `stmtcount(stmt, i) == length(newstmts(stmt, i))` if
    `isa(stmtcount(stmt, i), Int)`.

To gain a mental model for this function's behavior, consider the following scenario. Let's
say our `code` object contains several statements:

    code = Any[oldstmt1, oldstmt2, oldstmt3, oldstmt4, oldstmt5, oldstmt6]
    codelocs = Int[1, 2, 3, 4, 5, 6]

Let's also say that our `stmtcount` returns `2` for `stmtcount(oldstmt2, 2)`, returns `3`
for `stmtcount(oldstmt5, 5)`, and returns `nothing` for all other inputs. From this setup, we
can think of `code`/`codelocs` being modified in the following manner:

    newstmts2 = newstmts(oldstmt2, 2)
    newstmts5 = newstmts(oldstmt5, 5)
    code = Any[oldstmt1,
               newstmts2[1], newstmts2[2],
               oldstmt3, oldstmt4,
               newstmts5[1], newstmts5[2], newstmts5[3],
               oldstmt6]
    codelocs = Int[1, 2, 2, 3, 4, 5, 5, 5, 6]

See also: [`replace_match!`](@ref), [`is_ir_element`](@ref)
"""
function insert_statements!(code, codelocs, stmtcount, newstmts)
    ssachangemap = fill(0, length(code))
    labelchangemap = fill(0, length(code))
    worklist = Tuple{Int,Int}[]
    for i in 1:length(code)
        stmt = code[i]
        nstmts = stmtcount(stmt, i)
        if nstmts !== nothing
            addedstmts = nstmts - 1
            push!(worklist, (i, addedstmts))
            ssachangemap[i] = addedstmts
            if i < length(code)
                labelchangemap[i + 1] = addedstmts
            end
        end
    end
    Core.Compiler.renumber_ir_elements!(code, ssachangemap, labelchangemap)
    for (i, addedstmts) in worklist
        i += ssachangemap[i] - addedstmts # correct the index for accumulated offsets
        stmts = newstmts(code[i], i)
        @assert length(stmts) == (addedstmts + 1)
        code[i] = stmts[end]
        for j in 1:(length(stmts) - 1) # insert in reverse to maintain the provided ordering
            insert!(code, i, stmts[end - j])
            insert!(codelocs, i, codelocs[i])
        end
    end
end

function transform!(mi, src)
    stmtcount = (x, i) -> begin
        isassign = Base.Meta.isexpr(x, :(=))
        stmt = isassign ? x.args[2] : x
        if Base.Meta.isexpr(stmt, :call)
            return 4
        end
        return nothing
    end
    newstmts = (x, i) -> begin
        callstmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
        isapplycall = is_ir_element(callstmt.args[1], GlobalRef(Core, :_apply), src.code)
        isapplyiteratecall = is_ir_element(callstmt.args[1], GlobalRef(Core, :_apply_iterate), src.code)
        if isapplycall || isapplyiteratecall
            callf = callstmt.args[2]
            callargs = callstmt.args[3:end]
            stmts = Any[
                Expr(:call,
                     GlobalRef(Core, :_call_within), nothing,
                     prehook, callf, callargs...),
                callstmt,
                Expr(:call,
                     GlobalRef(Core, :_call_within), nothing,
                     posthook, SSAValue(i + 1), callf, callargs...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], SSAValue(i + 1)) : SSAValue(i + 1)
            ]
        else
            stmts = Any[
                Expr(:call, GlobalRef(Core, :_call_within), nothing, prehook, callstmt.args...),
                callstmt,
                Expr(:call, GlobalRef(Core, :_call_within), nothing, posthook, SSAValue(i + 1), callstmt.args...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], SSAValue(i + 1)) : SSAValue(i + 1)
            ]
        end
        return stmts
    end
    insert_statements!(src.code, src.codelocs, stmtcount, newstmts)
    return nothing
end

# TODO:
# - anything called from prehook is still instrumented
#   - either we can stop instrumentation / or switch to native (high overhead)
# - Similar problem with invokelatest, we somehow got recursion
# - invokelatest trace sees prehook call -- do we double transform?
struct Call
    parent
    f
    args
    children
end

# TODO: Handle task-safety
using Base.ScopedValues # provides `ScopedValue` and `@with`
const call_tree = ScopedValue{Ref{Call}}()

function prehook(f, args...)
    parent = call_tree[][]
    current = Call(parent, f, args, Call[])
    push!(parent.children, current)
    call_tree[][] = current
end

function posthook(_, f, args...)
    current = call_tree[][]
    call_tree[][] = current.parent
end

function f()
end

function trace(f, args...)
    top = Call(nothing, f, args, Call[])
    @with call_tree => Ref(top) begin
        Base.invoke_within(Tracer(), f, args...)
    end
    return top
end

trace(f)
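
The call-tree bookkeeping done by `prehook` and `posthook` can be tried without this PR by invoking the hooks by hand around each call. The following is a sketch for stock Julia 1.11 or newer: `ToyCall`, `toy_tree`, `toy_pre`, `toy_post`, `traced_call`, and `toy_trace` are hypothetical names, and `traced_call` stands in for the statements that `transform!` would insert automatically.

```julia
using Base.ScopedValues  # requires Julia 1.11 or newer

struct ToyCall
    parent::Union{ToyCall,Nothing}
    f::Any
    args::Tuple
    children::Vector{ToyCall}
end

# Points at the node of the call that is currently executing.
const toy_tree = ScopedValue{Base.RefValue{ToyCall}}()

function toy_pre(f, args...)
    parent = toy_tree[][]
    node = ToyCall(parent, f, args, ToyCall[])
    push!(parent.children, node)
    toy_tree[][] = node          # descend into the new call
    return nothing
end

function toy_post(_, f, args...)
    toy_tree[][] = toy_tree[][].parent  # climb back to the caller
    return nothing
end

# Hand-written stand-in for the instrumentation around one call.
function traced_call(f, args...)
    toy_pre(f, args...)
    ret = f(args...)
    toy_post(ret, f, args...)
    return ret
end

inner(x) = traced_call(+, x, 1)
outer(x) = traced_call(inner, x) * 2  # the `*` call is left uninstrumented

function toy_trace(f, args...)
    top = ToyCall(nothing, f, args, ToyCall[])
    with(toy_tree => Ref(top)) do
        f(args...)
    end
    return top
end

top = toy_trace(outer, 3)
```

Here `top` has a single child for `inner`, which in turn has a single child for `+` with arguments `(3, 1)`.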

vchuravy, Jan 18 '24 16:01

One thing I particularly enjoyed about Cassette was that composition is well defined. This proposal ignores composition: compiler instances don't form a stack, and there is no expectation that the output of one is meant to be consumed by another.

For Cassette, uninferred IR was a viable communication layer, but with compiler instances customization can occur at many levels, and we run into the pipeline ordering problem. The hope is that, using compiler instances, we could build an actual "compiler plugins" infrastructure that allows for the registration of passes/intrinsics and provides sensible composition, but that seems further away, and I do think we need some experimentation with compiler customization first before we tackle that.
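
A tiny illustration of the ordering problem, using two made-up passes over surface expressions (`fold_adds` and `retarget_adds` are hypothetical, as is the `traced_add` symbol they emit). It is only meant to show that two independently reasonable passes need not commute.

```julia
# Pass A: fold `+` calls whose arguments are all numeric literals.
fold_adds(x) = x
function fold_adds(ex::Expr)
    args = Any[fold_adds(a) for a in ex.args]
    if ex.head === :call && length(args) > 1 && args[1] === :+ &&
       all(a -> a isa Number, args[2:end])
        return +(args[2:end]...)
    end
    return Expr(ex.head, args...)
end

# Pass B: redirect every `+` call to a hypothetical `traced_add`.
retarget_adds(x) = x
function retarget_adds(ex::Expr)
    args = Any[retarget_adds(a) for a in ex.args]
    if ex.head === :call && args[1] === :+
        args[1] = :traced_add
    end
    return Expr(ex.head, args...)
end

ex = :(x * (1 + 2))
a_then_b = retarget_adds(fold_adds(ex))  # the addition is folded before B sees it
b_then_a = fold_adds(retarget_adds(ex))  # B hides the addition from A
```

Running A first removes the call that B wanted to intercept, while running B first blocks A's optimization, so any plugin infrastructure has to decide which order is intended.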

vchuravy, Jan 18 '24 19:01