Enzyme.jl

Allow custom rule for constant arg/ret in rev mode

Open wsmoses opened this issue 10 months ago • 5 comments

wsmoses avatar Mar 31 '24 12:03 wsmoses

Codecov Report

Attention: Patch coverage is 91.17647%, with 9 lines in your changes missing coverage. Please review.

Project coverage is 70.55%. Comparing base (724b9bc) to head (73a3fc5).

| Files | Patch % | Lines |
|---|---|---|
| src/rules/customrules.jl | 90.21% | 9 Missing :warning: |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1371      +/-   ##
==========================================
- Coverage   75.40%   70.55%   -4.86%     
==========================================
  Files          36       36              
  Lines       10671    10276     -395     
==========================================
- Hits         8047     7250     -797     
- Misses       2624     3026     +402     

codecov-commenter avatar Mar 31 '24 13:03 codecov-commenter

This works with arrays. I've now checked with KA using CPU() and CUDABackend(). What is odd is that the rule gets triggered with CUDABackend() but not with CPU().

Also, with CUDABackend() (but not with CPU()) I get lots of these warnings:

┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Union{UInt64, String}}}
└ @ Enzyme /disk/mschanen/julia_depot/packages/GPUCompiler/U36Ed/src/utils.jl:59

The MWE is here (it needs a local Enzyme_jll with an Enzyme build).

This only prints:

Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}

Instead of

Forward synchronize for Const{CPU}
Reverse synchronize for Const{CPU}
Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}

michel2323 avatar Apr 01 '24 20:04 michel2323

Could it be that Enzyme makes decisions based on the code of synchronize(backend::T) where T? I think that for CUDABackend() it goes through KA.synchronize(::CUDABackend) = CUDA.synchronize(), whereas KA.synchronize(::CPU) = nothing probably gets removed as dead code before a rule is applied.
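To make the dead-code hypothesis concrete, here is a minimal plain-Julia sketch. It does not use Enzyme or KA at all: FakeCPU, FakeGPU, fake_synchronize, call_sync and remaining_calls are hypothetical stand-ins invented for illustration, and the sketch only looks at Julia's own optimized IR, not at Enzyme's pipeline. It counts how many call-like statements survive in the caller for each backend type, which is one way to see whether a trivial method leaves anything behind for a later stage to intercept.

```julia
# Hypothetical stand-ins for KA's CPU and CUDABackend; these are NOT the real types.
struct FakeCPU end
struct FakeGPU end

# Stand-in for a real side effect such as CUDA.synchronize().
const SYNC_COUNT = Ref(0)

fake_synchronize(::FakeCPU) = nothing
fake_synchronize(::FakeGPU) = (SYNC_COUNT[] += 1; nothing)

# Caller analogous to a function that synchronizes a backend and returns nothing.
function call_sync(backend)
    fake_synchronize(backend)
    return nothing
end

# Count the call-like statements left in Julia's optimized IR for `f`.
function remaining_calls(f, argtypes)
    ci, _ = only(code_typed(f, argtypes; optimize=true))
    return count(s -> s isa Expr && s.head in (:call, :invoke, :foreigncall), ci.code)
end

println("FakeCPU: ", remaining_calls(call_sync, (FakeCPU,)), " call-like statement(s) left")
println("FakeGPU: ", remaining_calls(call_sync, (FakeGPU,)), " call-like statement(s) left")
```

If the FakeCPU caller has nothing left while the FakeGPU caller still does, that is at least consistent with the idea above. It says nothing about where Enzyme itself decides whether to apply a rule.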

@vchuravy Didn't we observe Enzyme hitting weird stuff with the KA kernel rules? I wonder whether some stage of Enzyme doesn't go through the whole function that has a rule defined.

michel2323 avatar Apr 01 '24 20:04 michel2323

Narrowed it down.

using Enzyme
using EnzymeCore
using EnzymeCore.EnzymeRules

struct MyConst end
struct MyConst2
    v::Vector{Float64}
end
MyConst2() = MyConst2(zeros(5))

bar(x::MyConst)::Nothing = nothing
function bar(x::MyConst2)
    x.v .*= 2.0
    nothing
end

function foo(myconst)
    bar(myconst)
    return nothing
end

function EnzymeRules.augmented_primal(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    myconst::T
) where T <: EnzymeCore.Annotation
    println("bar aug_fwd rule $(typeof(myconst))")
    return AugmentedReturn{Nothing, Nothing, Any}(
        nothing, nothing, (nothing)
    )
end

function EnzymeRules.reverse(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    tape,
    myconst
)
    println("bar rev rule $(typeof(myconst))")
    return (nothing,)
end

function driver(myconst)
    println("Running $(typeof(myconst()))")
    Enzyme.autodiff(
        ReverseWithPrimal, foo, Const(myconst())
    )
end

# Doesn't trigger rules above
driver(MyConst)
# Triggers rules above
driver(MyConst2)

This outputs

Running MyConst
Running MyConst2
bar aug_fwd rule Const{MyConst2}
bar rev rule Const{MyConst2}

when the rule should also be applied in the case of MyConst.
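One difference between the two cases that can be checked without Enzyme is the layout of the argument types. The sketch below redefines the two structs from the MWE so it runs on its own and prints a few Base reflection results for each. Whether this difference is what makes the rule get skipped is an assumption, not something established here.

```julia
# Same definitions as in the MWE, repeated so this snippet is self-contained.
struct MyConst end
struct MyConst2
    v::Vector{Float64}
end

# Compare basic layout properties of the two argument types.
for T in (MyConst, MyConst2)
    println(T,
        ": fieldcount = ", fieldcount(T),
        ", singleton = ", Base.issingletontype(T),
        ", isbits = ", isbitstype(T))
end
```

MyConst has no fields and is a singleton type, and its bar method has no side effects. MyConst2 wraps a mutable Vector that bar modifies in place.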

michel2323 avatar Apr 02 '24 15:04 michel2323

As per our chat, this should be all resolved and ready to be merged.

michel2323 avatar Apr 09 '24 15:04 michel2323