
Add support for user-defined rules

Open vchuravy opened this issue 3 years ago • 9 comments

  • Allow user defined opt-out of inlining
  • Add enzyme_code_typed
  • WIP: check that method is not compiled out
  • Recursive codegen

vchuravy, Jan 23 '22

What is the development status of this now?

zsz00, May 30 '22

What is actually the current state of this? Is this coming any time soon? And is there already another way to define rules? I've seen it mentioned a few times.

maximilian-gelbrecht, Jun 28 '22

Pending a design decision on #172

vchuravy, Aug 04 '22

The API still needs work, and right now it only adds custom rules for forward mode.

Nevertheless, it does work and has tests (which it hopefully passes), so I want to merge.

wsmoses, Aug 26 '22
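For context on what a forward-mode custom rule has to produce: for y = f(x) the rule receives the primal x together with a tangent ẋ and returns the output tangent ẏ (plus the primal y itself when the return annotation is `Duplicated` rather than `DuplicatedNoNeed`). In batch mode the same derivative is applied to each of the N tangents. For the f(x) = x^2 example used later in this thread:

```latex
\dot{y} = f'(x)\,\dot{x}, \qquad
\dot{y}_i = f'(x)\,\dot{x}_i \quad (i = 1, \dots, N), \qquad
f(x) = x^2 \;\Rightarrow\; \dot{y} = 2x\,\dot{x}
```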

The 1.6 failure seems like an inlining issue?

wsmoses, Aug 26 '22

Codecov Report

Merging #177 (b5ce463) into main (0f3233d) will decrease coverage by 1.99%. The diff coverage is 70.71%.

@@            Coverage Diff             @@
##             main     #177      +/-   ##
==========================================
- Coverage   76.33%   74.33%   -2.00%     
==========================================
  Files          17       18       +1     
  Lines        4226     4419     +193     
==========================================
+ Hits         3226     3285      +59     
- Misses       1000     1134     +134     
Impacted Files Coverage Δ
src/Enzyme.jl 86.87% <ø> (ø)
src/compiler/reflection.jl 90.32% <ø> (ø)
src/compiler.jl 74.24% <67.77%> (-0.56%) :arrow_down:
src/rules.jl 83.33% <83.33%> (ø)
src/api.jl 58.13% <100.00%> (+0.24%) :arrow_up:
src/compiler/interpreter.jl 98.03% <100.00%> (-1.97%) :arrow_down:
src/compiler/pmap.jl 84.86% <100.00%> (+3.61%) :arrow_up:
src/compiler/orcv1.jl 0.00% <0.00%> (-80.77%) :arrow_down:
src/compiler/utils.jl 56.16% <0.00%> (-36.99%) :arrow_down:
src/compiler/validation.jl 59.73% <0.00%> (-4.63%) :arrow_down:
... and 5 more


codecov-commenter, Aug 26 '22

Vetoing merging this for now; we should first release a version with the accumulated fixes before landing this PR.

vchuravy, Sep 05 '22

@vchuravy, now that the jll bump has landed, do you have time to review this (and possibly ponder the 1.6 no-inlining issue)?

wsmoses, Sep 13 '22

I have a strong preference to focus on stabilizing Enzyme.jl for a new release (a stretch goal for that is GC integration) and to revisit this for the next minor bump of Enzyme.jl. I am not convinced that we want to commit to the interface yet, and I would like to see the current internal custom rules ported over to this format so that we can evaluate whether or not it is the right approach.

vchuravy, Sep 14 '22

I haven't thought through reverse mode yet, but I think:

using EnzymeCore
import EnzymeCore.EnzymeRules: forward

f(x) = x^2

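# The constant offsets (10, 1000, 100, 10000) appear to be markers that show
# which rule was dispatched; the actual derivative of x^2 is 2*x.val*x.dval.

# Return annotation DuplicatedNoNeed: only the tangent is returned.
function forward(::Const{typeof(f)}, RT::Type{<:DuplicatedNoNeed}, x::Duplicated)
    return 10+2*x.val*x.dval
end

# Batched, tangents only: one tangent per entry of x.dval.
function forward(::Const{typeof(f)}, RT::Type{<:BatchDuplicatedNoNeed}, x::BatchDuplicated{T, N}) where {T, N}
    return NTuple{N, T}(1000+2*x.val*dv for dv in x.dval)
end

# Return annotation Duplicated: primal and tangent.
function forward(func::Const{typeof(f)}, RT::Type{<:Duplicated}, x::Duplicated)
    return Duplicated(func.val(x.val), 100+2*x.val*x.dval)
end

# Batched: primal plus one tangent per entry of x.dval.
function forward(func::Const{typeof(f)}, RT::Type{<:BatchDuplicated}, x::BatchDuplicated{T, N}) where {T,N}
    return BatchDuplicated(func.val(x.val), NTuple{N, T}(10000+2*x.val*dv for dv in x.dval))
end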

import EnzymeCore: Annotation

# Is there any forward rule for f, for any return annotation and arguments?
function has_frule(@nospecialize(f))
    TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{<:Annotation}, Vararg{Annotation}}
    isapplicable(forward, TT)
end

# Do we need this one?
# Is there a forward rule for f with return annotation RT?
function has_frule(@nospecialize(f), @nospecialize(RT))
    TT = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, Vararg{Annotation}}
    isapplicable(forward, TT)
end

# Do we need this one?
# As above, additionally restricted to argument annotations TT, e.g. Tuple{<:Duplicated}.
function has_frule(@nospecialize(f), @nospecialize(RT), @nospecialize(TT))
    inner = Base.unwrap_unionall(TT)
    sig = Tuple{<:Annotation{Core.Typeof(f)}, Type{RT}, inner.parameters...}
    # Re-wrap so the type variables bound in TT (the `<:` parts) are not left free.
    isapplicable(forward, Base.rewrap_unionall(sig, TT))
end

# Base.hasmethod is a precise match; we want the broader query:
# does any method of f intersect the signature TT?
function isapplicable(@nospecialize(f), @nospecialize(TT); world=Base.get_world_counter())
    tt = Base.to_tuple_type(TT)
    sig = Base.signature_type(f, tt)
    return !isempty(Base._methods_by_ftype(sig, -1, world)) # TODO cheaper way of querying?
end

@assert has_frule(f)
@assert has_frule(f, Duplicated)
@assert has_frule(f, DuplicatedNoNeed)
@assert has_frule(f, BatchDuplicated)
@assert has_frule(f, BatchDuplicatedNoNeed)
@assert has_frule(f, Duplicated, Tuple{<:Duplicated})
@assert has_frule(f, DuplicatedNoNeed, Tuple{<:Duplicated})
@assert has_frule(f, BatchDuplicated, Tuple{<:BatchDuplicated})
@assert has_frule(f, BatchDuplicatedNoNeed, Tuple{<:BatchDuplicated})

@assert !has_frule(f, Duplicated, Tuple{<:BatchDuplicated})
@assert !has_frule(f, DuplicatedNoNeed, Tuple{<:BatchDuplicated})
@assert !has_frule(f, BatchDuplicated, Tuple{<:Duplicated})
@assert !has_frule(f, BatchDuplicatedNoNeed, Tuple{<:Duplicated})

Is a decent start for the forward API.

vchuravy, Oct 09 '22
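To try the dispatch-based rule lookup from the sketch above without installing EnzymeCore, here is a minimal self-contained version. It assumes hypothetical stand-ins for the annotation types (`Annotation`, `Const`, `Duplicated`, `DuplicatedNoNeed` below are not the real EnzymeCore definitions), a hypothetical `forward` function and example `sq`, restricts the first argument to `Const` and omits the batch variants for brevity. Instead of `Base._methods_by_ftype` it uses `methods(f, types)`, which also returns every method whose signature intersects the query.

```julia
# Hypothetical stand-ins for the EnzymeCore annotation types; NOT the real
# definitions, just enough structure to exercise dispatch.
abstract type Annotation{T} end

struct Const{T} <: Annotation{T}
    val::T
end

struct Duplicated{T} <: Annotation{T}
    val::T
    dval::T   # tangent
end

struct DuplicatedNoNeed{T} <: Annotation{T}
    val::T
    dval::T   # tangent; the caller does not need the primal result
end

# Hypothetical rule function standing in for the proposed EnzymeRules.forward.
function forward end

sq(x) = x^2   # illustrative function with hand-written rules

# Return annotation Duplicated: return primal and tangent together.
function forward(func::Const{typeof(sq)}, ::Type{<:Duplicated}, x::Duplicated)
    return Duplicated(func.val(x.val), 2 * x.val * x.dval)
end

# Return annotation DuplicatedNoNeed: return only the tangent.
function forward(::Const{typeof(sq)}, ::Type{<:DuplicatedNoNeed}, x::Duplicated)
    return 2 * x.val * x.dval
end

# Any forward rule for f at all? `methods(f, types)` keeps every method whose
# signature intersects `types`, so an abstract query still finds concrete rules.
function has_frule(@nospecialize(f))
    TT = Tuple{Const{Core.Typeof(f)}, Type{<:Annotation}, Vararg{Annotation}}
    return !isempty(methods(forward, TT))
end

# A rule for f with exact return annotation RT and argument annotations TT,
# e.g. has_frule(sq, Duplicated, Tuple{<:Duplicated}).
function has_frule(@nospecialize(f), @nospecialize(RT), @nospecialize(TT))
    inner = Base.unwrap_unionall(TT)
    sig = Tuple{Const{Core.Typeof(f)}, Type{RT}, inner.parameters...}
    # Re-wrap so the `<:` type variables of TT stay bound.
    return !isempty(methods(forward, Base.rewrap_unionall(sig, TT)))
end

# Usage: d/dx x^2 at x = 3 with seed 1.
res = forward(Const(sq), Duplicated, Duplicated(3.0, 1.0))   # Duplicated(9.0, 6.0)
```

As in the original sketch, the return annotation is matched exactly via `Type{RT}`, so asking for `Duplicated` does not pick up the `DuplicatedNoNeed` rule, while argument mismatches (for example a `Const` argument against a `Duplicated` rule) make the intersection empty.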

Rebased in #589

vchuravy, Feb 03 '23