
EnzymeRules

Open vchuravy opened this issue 3 years ago • 6 comments

We need an easier way to register custom adjoints that are compatible with Enzyme. The tricky bit here is coming up with a good ABI.

The goal would be to somewhat mimic Diffractor.jl (cc: @simeonschaub), meaning that rule selection would run before activity analysis, during Julia inference. Alternatively, we could have a callback function that generates additional code after activity analysis, but I would rather avoid that kind of unlimited compiler recursion.

Take Base.sin as an example. During type inference we have Tuple{typeof(Base.sin), Float64} (and possibly the return type), and we need some way to map that onto:

```julia
using Enzyme: Active, Const

# Active return value: propagate ret_grad through d/dx sin(x) = cos(x).
function tfunc_rrrule(::typeof(Base.sin), ::Type{Active}, ::Type{Active{T}}) where T
    function ∇sin(ret_grad, x::Active{T}, tape) where T
        return cos(x.val) * ret_grad
    end
end

# Const return value: no gradient flows back, so the input receives zero.
function tfunc_rrrule(::typeof(Base.sin), ::Type{Const}, ::Type{Active{T}}) where T
    function ∇sin(ret_grad, x::Active{T}, tape) where T
        return zero(T)
    end
end
```

or:

```julia
# Same rules as plain methods, dispatching on the activity of the return gradient.
function rrrule(::typeof(Base.sin), ret_grad::Active{T}, x::Active{T}, tape) where T
    return cos(x.val) * ret_grad.val
end

function rrrule(::typeof(Base.sin), ret_grad::Const{T}, x::Active{T}, tape) where T
    return zero(T)
end
```
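The second form relies only on ordinary multiple dispatch over activity annotations. A minimal sketch that runs without Enzyme, assuming hypothetical stand-in wrappers `MyActive`/`MyConst` (not Enzyme's types) and a hypothetical rule name `sin_rule`:

```julia
# Hypothetical stand-ins for Enzyme's activity annotations (not Enzyme's types).
struct MyActive{T}
    val::T
end
struct MyConst{T}
    val::T
end

# Active return gradient, active input: chain the incoming gradient through cos.
sin_rule(ret_grad::MyActive{T}, x::MyActive{T}, tape) where {T} = cos(x.val) * ret_grad.val

# Const return gradient: nothing flows back, so the input gets zero.
sin_rule(ret_grad::MyConst, x::MyActive{T}, tape) where {T} = zero(T)

sin_rule(MyActive(1.0), MyActive(0.0), nothing)  # cos(0.0) * 1.0 == 1.0
sin_rule(MyConst(1.0), MyActive(0.5), nothing)   # 0.0
```

Because the two methods differ only in the wrapper type of the return gradient, the choice of rule could in principle be resolved from types alone, before any activity analysis runs.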

vchuravy, Jan 12 '22 15:01

Cf. https://github.com/JuliaDiff/ChainRulesCore.jl/issues/452, and in particular the gist https://gist.github.com/oxinabox/c6ad25c468b3108f8a799bda66c147f8/, which roughly shows how existing rules can be used with activity annotations.

I suggest prototyping this in a way that depends on ChainRules and ChainRulesCore. Then, once it is nice and stable, we can move it into those packages, since similar things are wanted by operator-overloading ADs like Nabla.

That would probably push the tape etc. into the first (config) argument; see https://juliadiff.org/ChainRulesCore.jl/dev/rule_author/superpowers/ruleconfig.html. But we can always mess with that later.

oxinabox, Jan 12 '22 15:01

I am a bit wary of depending on ChainRules immediately, since most of its rules are written in a closure-passing/capture style, and I am not sure we can make that work well with Enzyme. That's why the rules in the example above explicitly receive the input arguments.

But maybe @wsmoses has some clever idea for squaring the circle.
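One possible way to square the closure style with an explicit-input interface is to make the pullback closure itself the tape. The sketch below avoids any ChainRulesCore dependency and uses hypothetical names (`closure_rrule`, `adapted_augmented_primal`, `adapted_reverse`); whether Enzyme could store and differentiate through such closures efficiently is exactly the open concern here.

```julia
# A ChainRules-style rule in closure-capture style (hypothetical, no
# ChainRulesCore dependency): returns the primal value and a pullback
# closure that captures `x`.
function closure_rrule(::typeof(sin), x::Real)
    y = sin(x)
    sin_pullback(dy) = (nothing, cos(x) * dy)  # (d/d function, d/dx)
    return y, sin_pullback
end

# Hypothetical adapter onto an augmented-primal / reverse split:
# the captured pullback closure becomes the tape.
function adapted_augmented_primal(f, x)
    y, pullback = closure_rrule(f, x)
    return y, pullback  # (primal value, tape)
end

# The reverse pass just runs the stored pullback on the incoming gradient.
function adapted_reverse(f, dret, tape)
    _, dx = tape(dret)
    return (dx,)
end
```

For example, `y, tape = adapted_augmented_primal(sin, 0.0)` followed by `adapted_reverse(sin, 2.0, tape)` gives `(2.0,)`.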

vchuravy, Jan 12 '22 16:01

I think the ChainRules and proposed EnzymeRules interface solve two orthogonal problems.

The design proposed here is needed to give Enzyme an interface for accepting arbitrary custom derivatives written in Julia (besides the ones we've hard-coded internally). This immediately raises questions about how we resolve such custom rules with respect to argument types, activity, etc., given Enzyme's internal conventions.

I think the ChainRules dependency solves a related but technically distinct problem: importing existing custom-rule definitions, as opposed to creating the interface itself.

I'm inclined to continue on both fronts. They should eventually meet once the custom-derivative registration interface here has been created and ChainRules supports activity annotations.

wsmoses, Jan 12 '22 17:01

```julia
module EnzymeRules
    """
        augmented_primal(::typeof(f), args...)

    Return the primal computation value and a tape.
    """
    function augmented_primal end

    """
        reverse(::typeof(f), activities, dret, tape)

    Take the activity annotations, the gradient of the return value, and the tape;
    return one derivative per argument.
    """
    function reverse end
end

using Test
using .EnzymeRules               # local module defined above, hence the leading dot
using Enzyme                     # provides Active, Const
using ChainRulesCore: NoTangent  # only for the NoTangent() option below

f(x) = x^2
# The tape stores x, since d/dx x^2 = 2x is needed in the reverse pass.
EnzymeRules.augmented_primal(::typeof(f), x::Active) = (f(x.val), x.val)
EnzymeRules.reverse(::typeof(f), ::Tuple{Active}, dret, x) = (2*x*dret,)

EnzymeRules.reverse(::typeof(f), ::Tuple{Const}, dret, tape) = (zero(dret),)
# or is this
# EnzymeRules.reverse(::typeof(f), ::Tuple{Const}, dret, tape) = (NoTangent(),)

g(x, y) = x * y
# Each tape keeps only what the reverse pass needs for the active arguments.
EnzymeRules.augmented_primal(::typeof(g), x::Active, y::Active) = (g(x.val, y.val), (x.val, y.val))
EnzymeRules.augmented_primal(::typeof(g), x::Const, y::Active) = (g(x.val, y.val), x.val)
EnzymeRules.augmented_primal(::typeof(g), x::Active, y::Const) = (g(x.val, y.val), y.val)

EnzymeRules.reverse(::typeof(g), ::Tuple{Active, Active}, dret, (x, y)) = (y*dret, x*dret)

EnzymeRules.reverse(::typeof(g), ::Tuple{Const, Active}, dret, x) = (NoTangent(), x*dret)
EnzymeRules.reverse(::typeof(g), ::Tuple{Active, Const}, dret, y) = (y*dret, NoTangent())

# TODO:
# - Duplicated?
# - NoTangent() vs nothing
# - Combinatorial explosion
# - activity dispatch doesn't quite work yet
```
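As a check that the augmented-primal/reverse split composes, here is a minimal sketch of the same protocol that runs without Enzyme and drives a reverse sweep by hand for h(x, y) = g(f(x), y). The wrappers `Act`/`Cst` and the names `augmented_primal`, `reverse_rule`, and `gradient_h` are hypothetical; passing the activity pattern as a tuple type and returning `nothing` for Const slots are just two of the open choices in the TODO list.

```julia
# Hypothetical activity wrappers standing in for Enzyme's annotations.
struct Act{T}; val::T; end
struct Cst{T}; val::T; end

f(x) = x^2
g(x, y) = x * y

# augmented_primal returns (primal value, tape); the tape keeps what reverse needs.
augmented_primal(::typeof(f), x::Act) = (f(x.val), x.val)
augmented_primal(::typeof(g), x::Act, y::Act) = (g(x.val, y.val), (x.val, y.val))
augmented_primal(::typeof(g), x::Cst, y::Act) = (g(x.val, y.val), x.val)
augmented_primal(::typeof(g), x::Act, y::Cst) = (g(x.val, y.val), y.val)

# reverse_rule gets the activity pattern as a tuple type, the incoming gradient,
# and the tape; it returns one entry per argument (`nothing` for Const ones).
reverse_rule(::typeof(f), ::Type{Tuple{Act}}, dret, x) = (2 * x * dret,)
reverse_rule(::typeof(g), ::Type{Tuple{Act,Act}}, dret, (x, y)) = (y * dret, x * dret)
reverse_rule(::typeof(g), ::Type{Tuple{Cst,Act}}, dret, x) = (nothing, x * dret)
reverse_rule(::typeof(g), ::Type{Tuple{Act,Cst}}, dret, y) = (y * dret, nothing)

# Hand-written reverse sweep for h(x, y) = g(f(x), y) with both inputs active.
function gradient_h(x, y)
    # Forward sweep: run the augmented primals and keep their tapes.
    a, tape_f = augmented_primal(f, Act(x))
    out, tape_g = augmented_primal(g, Act(a), Act(y))
    # Reverse sweep: seed with 1 and apply the rules in reverse order.
    da, dy = reverse_rule(g, Tuple{Act,Act}, one(out), tape_g)
    (dx,) = reverse_rule(f, Tuple{Act}, da, tape_f)
    return out, (dx, dy)
end

gradient_h(3.0, 2.0)  # (18.0, (12.0, 9.0)): h = x^2 * y, dh/dx = 2xy, dh/dy = x^2
```

Even this toy g needs three methods for two arguments, a small instance of the combinatorial explosion noted in the TODO.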

vchuravy, Aug 04 '22 12:08

> Combinatorial explosion

Yeah.

ChrisRackauckas, Aug 04 '22 15:08

Per an earlier offline discussion: Active vs. Const is easy, since a Const rule can always fall back to the Active rule and simply not use the derivative result. Active and Duplicated are incompatible, so a given value can only be one of the two (e.g., a float cannot be Duplicated in reverse mode). Duplicated vs. Const should have the correct rules available. Hopefully the design enables taking, say, a union of the two and performing the Duplicated update conditionally on a type check.

Nevertheless, this merits design iteration.
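Both fallback ideas can be illustrated without Enzyme. The sketch assumes hypothetical wrappers `Act`, `Cst`, and `Dup` (standing in for Active, Const, and Duplicated) and hypothetical rule names `rev_square` and `rev_double!`. It shows a Const method that reuses the Active rule and discards its result, and a single method over a `Union` of Duplicated and Const that updates the shadow only after a type check. `Dup` is restricted to arrays here, reflecting that a float cannot be Duplicated in reverse mode.

```julia
# Hypothetical activity wrappers (stand-ins for Active, Const, Duplicated).
struct Act{T}; val::T; end
struct Cst{T}; val::T; end
struct Dup{T}; val::T; dval::T; end   # primal value plus its shadow

# Active rule for f(x) = x^2: the derivative contribution for x.
rev_square(dret, x::Act) = 2 * x.val * dret

# Const falls back to the Active rule and simply drops the result.
function rev_square(dret, x::Cst)
    rev_square(dret, Act(x.val))  # computed but unused
    return nothing
end

# Reverse rule for the in-place primal `x .*= 2`: the adjoint of the old x is
# twice the adjoint of the new x. One method covers Duplicated and Const;
# the shadow is only touched when a shadow exists.
function rev_double!(x::Union{Dup{<:AbstractVector},Cst{<:AbstractVector}})
    if x isa Dup
        x.dval .*= 2
    end
    return nothing
end

d = Dup([1.0, 2.0], [1.0, 1.0])
rev_double!(d)   # d.dval is now [2.0, 2.0]
c = Cst([1.0, 2.0])
rev_double!(c)   # no shadow, nothing to update
```

The type check on a concrete argument type is known at compile time, so the branch would presumably be folded away in each specialization.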

wsmoses, Aug 04 '22 15:08

EnzymeRules has landed on main.

wsmoses, Feb 05 '23 01:02