Enzyme.jl
Allow custom rule for constant arg/ret in rev mode
Codecov Report
Attention: Patch coverage is 91.17647%, with 9 lines in your changes missing coverage. Please review.
Project coverage is 70.55%. Comparing base (724b9bc) to head (73a3fc5).
| File | Patch % | Lines |
|---|---|---|
| src/rules/customrules.jl | 90.21% | 9 missing |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1371      +/-   ##
==========================================
- Coverage   75.40%   70.55%   -4.86%
==========================================
  Files          36       36
  Lines       10671    10276     -395
==========================================
- Hits         8047     7250     -797
- Misses       2624     3026     +402
```
This works with arrays. I've now checked with KA using `CPU()` and `CUDABackend()`. What is odd is that the rule is triggered with `CUDABackend()` but not with `CPU()`.
Also, with `CUDABackend()` I get lots of the following warnings, which do not appear with `CPU()`:
```
┌ Warning: Type does not have a definite number of fields
│   T = Tuple{Vararg{Union{UInt64, String}}}
└ @ Enzyme /disk/mschanen/julia_depot/packages/GPUCompiler/U36Ed/src/utils.jl:59
```
The MWE is here (it needs a local Enzyme_jll with an Enzyme build).
This only prints:
```
Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}
```
instead of:
```
Forward synchronize for Const{CPU}
Reverse synchronize for Const{CPU}
Forward synchronize for Const{CUDABackend}
Reverse synchronize for Const{CUDABackend}
```
Could it be that Enzyme makes decisions based on the code of `synchronize(backend::T) where T`? I suspect that for `CUDABackend()` the call goes through `KA.synchronize(::CUDABackend) = CUDA.synchronize()`, whereas `KA.synchronize(::CPU) = nothing` probably gets removed as dead code before a rule can be applied.
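One way to probe the dead-code hypothesis is to check what Julia's own optimizer does with a no-op method shaped like `KA.synchronize(::CPU) = nothing`. The sketch below uses hypothetical stand-ins (`FakeCPU`, `noop_sync`, `caller`) rather than KernelAbstractions itself, and it only inspects Julia's default pipeline via `code_typed`. Enzyme compiles through GPUCompiler with its own interpreter, so this is a hint about what *can* happen, not proof of what Enzyme sees.

```julia
# Hypothetical stand-ins; these are NOT KernelAbstractions' actual types or methods.
struct FakeCPU end                    # field-less backend, like CPU()
noop_sync(::FakeCPU) = nothing        # analogous to KA.synchronize(::CPU) = nothing

function caller(backend)
    noop_sync(backend)                # the call a custom rule would need to intercept
    return nothing
end

# Type-inferred, optimized IR of `caller` from Julia's default pipeline.
ci, rettype = only(code_typed(caller, (FakeCPU,); optimize = true))

# Collect any statements that still call or invoke a function.
remaining_calls = filter(s -> Meta.isexpr(s, :invoke) || Meta.isexpr(s, :call), ci.code)

println("inferred return type: ", rettype)
println("calls left after optimization: ", length(remaining_calls))
```

If no calls remain, the no-op was inlined and eliminated by Julia's optimizer, which is consistent with (but does not confirm) the idea that there is nothing left for a rule to attach to in the `CPU()` case.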
@vchuravy Didn't we observe Enzyme hitting weird behavior with the KA kernel rules? I wonder whether some stage of Enzyme doesn't traverse the whole function for which a rule is defined.
Narrowed it down.
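```julia
using Enzyme
using EnzymeCore
using EnzymeCore.EnzymeRules

# Field-less (ghost) type: carries no runtime data
struct MyConst end

# Type that wraps a mutable array
struct MyConst2
    v::Vector{Float64}
end
MyConst2() = MyConst2(zeros(5))

# No-op method, analogous to KA.synchronize(::CPU) = nothing
bar(x::MyConst)::Nothing = nothing

# Method with a side effect on the wrapped array
function bar(x::MyConst2)
    x.v .*= 2.0
    nothing
end

function foo(myconst)
    bar(myconst)
    return nothing
end

# Custom reverse-mode rule for `bar` with a constant argument and a constant return.
# Note: the rule does not call `bar` itself, so the primal side effect is skipped.
function EnzymeRules.augmented_primal(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    myconst::T
) where T <: EnzymeCore.Annotation
    println("bar aug_fwd rule $(typeof(myconst))")
    # No primal result, no shadow, and no tape needed
    return AugmentedReturn{Nothing, Nothing, Any}(
        nothing, nothing, nothing
    )
end

function EnzymeRules.reverse(
    config::Config,
    func::Const{typeof(bar)},
    ::Type{Const{Nothing}},
    tape,
    myconst
)
    println("bar rev rule $(typeof(myconst))")
    # One entry per argument; `nothing` because `myconst` is constant
    return (nothing,)
end

function driver(myconst)
    println("Running $(typeof(myconst()))")
    Enzyme.autodiff(
        ReverseWithPrimal, foo, Const(myconst())
    )
end

# Doesn't trigger rules above
driver(MyConst)
# Triggers rules above
driver(MyConst2)
```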
This outputs:
```
Running MyConst
Running MyConst2
bar aug_fwd rule Const{MyConst2}
bar rev rule Const{MyConst2}
```
The rules should also have been applied for `MyConst`, but they are not.
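The two MWE types differ only in layout: `MyConst` has no fields (a singleton "ghost" type with no runtime data), while `MyConst2` wraps a `Vector{Float64}`. The sketch below just makes that difference explicit with Julia's reflection functions; this thread does not establish that this layout difference is what decides whether the rule fires, so treat it as an assumption to investigate. `Base.issingletontype` is an internal, unexported helper.

```julia
# Same type definitions as in the MWE (redefining identical structs is allowed).
struct MyConst end
struct MyConst2
    v::Vector{Float64}
end

# Compare the layout properties of the two argument types.
for T in (MyConst, MyConst2)
    println(rpad(string(T), 9),
            " fieldcount = ", fieldcount(T),
            ", singleton = ", Base.issingletontype(T),
            ", isbits = ", isbitstype(T))
end
```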
As per our chat, this should now be fully resolved and ready to merge.