Enzyme.jl
Allow custom rule for constant arg/ret in rev mode
Codecov Report
Attention: Patch coverage is 91.17647%, with 9 lines in your changes missing coverage. Please review.
Project coverage is 70.55%. Comparing base (724b9bc) to head (73a3fc5).
| Files | Patch % | Lines |
|---|---|---|
| src/rules/customrules.jl | 90.21% | 9 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #1371 +/- ##
==========================================
- Coverage 75.40% 70.55% -4.86%
==========================================
Files 36 36
Lines 10671 10276 -395
==========================================
- Hits 8047 7250 -797
- Misses 2624 3026 +402
This works with arrays. I've now checked with KA using CPU() and CUDABackend(). Oddly, the rule gets triggered with CUDABackend() but not with CPU().
Also, with CUDABackend() I get lots of these warnings, but not with CPU():
┌ Warning: Type does not have a definite number of fields
│ T = Tuple{Vararg{Union{UInt64, String}}}
└ @ Enzyme /disk/mschanen/julia_depot/packages/GPUCompiler/U36Ed/src/utils.jl:59
MWE is here (local Enzyme_jll with Enzyme build is needed)
This only prints:
Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}
Instead of
Forward synchronize for Const{CPU}
Reverse synchronize for Const{CPU}
Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}
Could it be that Enzyme makes decisions based on the code behind synchronize(backend::T) where T? I think that with CUDABackend() it goes through KA.synchronize(::CUDABackend) = CUDA.synchronize(), whereas KA.synchronize(::CPU) = nothing probably gets removed as dead code before a rule is applied.
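To make the dead-code hypothesis concrete, here is a minimal sketch in plain Julia, without Enzyme or KA. `EmptyBackend`, `noop_sync`, and `caller` are hypothetical stand-ins for `CPU`, `KA.synchronize(::CPU) = nothing`, and the calling function. It only inspects what Julia's own optimizer leaves of the call; Enzyme uses its own compilation pipeline, so this illustrates the suspected mechanism and does not confirm where the call is actually dropped.

```julia
using InteractiveUtils  # provides code_typed outside the REPL

# Hypothetical stand-ins for KA.CPU and KA.synchronize(::CPU) = nothing
struct EmptyBackend end
noop_sync(::EmptyBackend) = nothing

function caller(backend)
    noop_sync(backend)
    return nothing
end

# Optimized, typed IR for caller(::EmptyBackend)
ci, rettype = only(code_typed(caller, (EmptyBackend,)))

# Collect any call-like statements that survive optimization
remaining_calls = filter(ci.code) do stmt
    stmt isa Expr && stmt.head in (:call, :invoke)
end

println("return type: ", rettype)
println("remaining calls: ", length(remaining_calls))
```

If no call to `noop_sync` survives, a rule attached to that function has no call site left to match, which would be consistent with the rule firing only for the backend whose `synchronize` does real work.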
@vchuravy Didn't we observe Enzyme hitting weird stuff with the KA kernel rules? I somehow wonder whether some stage of Enzyme doesn't go through the whole function that has a rule defined.
Narrowed it down.
using Enzyme
using EnzymeCore
using EnzymeCore.EnzymeRules

struct MyConst end

struct MyConst2
    v::Vector{Float64}
end
MyConst2() = MyConst2(zeros(5))

bar(x::MyConst)::Nothing = nothing
function bar(x::MyConst2)
    x.v .*= 2.0
    nothing
end

function foo(myconst)
    bar(myconst)
    return nothing
end

function EnzymeRules.augmented_primal(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    myconst::T
) where T <: EnzymeCore.Annotation
    println("bar aug_fwd rule $(typeof(myconst))")
    return AugmentedReturn{Nothing, Nothing, Any}(
        nothing, nothing, (nothing)
    )
end

function EnzymeRules.reverse(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    tape,
    myconst
)
    println("bar rev rule $(typeof(myconst))")
    return (nothing,)
end

function driver(myconst)
    println("Running $(typeof(myconst()))")
    Enzyme.autodiff(
        ReverseWithPrimal, foo, Const(myconst())
    )
end

# Doesn't trigger rules above
driver(MyConst)
# Triggers rules above
driver(MyConst2)
This outputs
Running MyConst
Running MyConst2
bar aug_fwd rule Const{MyConst2}
bar rev rule Const{MyConst2}
when the rule should also be applied in the case of MyConst.
As per our chat, this should be all resolved and ready to be merged.