Enzyme.jl
Strange behavior of `Const` annotation
Here's an MWE involving differentiation of a function with closed-over cache storage.
using Enzyme
struct FuncWithCacheSum{C}
    cache::C
end
struct FuncWithCacheFirst{C}
    cache::C
end
function (f::FuncWithCacheSum)(x)
    (; cache) = f
    cache[1] = x[1]
    y = sum(abs2, cache) # only difference is here
    cache[1] = zero(eltype(cache))
    return y
end
function (f::FuncWithCacheFirst)(x)
    (; cache) = f
    cache[1] = x[1]
    y = abs2(cache[1]) # only difference is here
    cache[1] = zero(eltype(cache))
    return y
end
apply(x, f::F) where {F} = f(x)
x = [3.0]
fs = FuncWithCacheSum([0.0])
ff = FuncWithCacheFirst([0.0])
julia> ff(x) == fs(x)
true
julia> gradient(Reverse, apply, x, Const(fs))[1] # zero as expected
1-element Vector{Float64}:
0.0
julia> gradient(Reverse, apply, x, Const(ff))[1] # non-zero???
1-element Vector{Float64}:
6.0
Is it expected that the Const annotation doesn't seem to work for the second function?
LLVM can sometimes perform "store forwarding": when it sees a store followed directly by a load from the same location, it forwards the stored value straight to the load. This causes the activity not to be truncated at the `Const` cache, but rather to follow the value.
We would likely need an "activity blocking" intrinsic instead of relying on the semantics of `Const`.
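The store forwarding described above can be sketched by hand in plain Julia, without Enzyme. This is only an illustration of the effect, not what LLVM literally emits. The names `first_via_cache` and `first_forwarded` are hypothetical and appear nowhere in the issue. Once the load is replaced by the stored value, the result depends on `x` directly and the cache no longer sits between input and output, which is consistent with a derivative of `2 * x[1]` rather than zero.

```julia
# Hypothetical illustration only: no Enzyme involved, names are made up for this sketch.

# Shape of the FuncWithCacheFirst call: store into the cache, then load the same slot.
function first_via_cache(cache, x)
    cache[1] = x[1]                 # store
    y = abs2(cache[1])              # load of the slot that was just stored
    cache[1] = zero(eltype(cache))  # reset the cache
    return y
end

# Hand-written equivalent after store forwarding: the load is replaced by the
# stored value, so y is computed from x[1] without going through the cache.
function first_forwarded(cache, x)
    v = x[1]
    cache[1] = v
    y = abs2(v)                     # uses the forwarded value directly
    cache[1] = zero(eltype(cache))
    return y
end

x = [3.0]
cache = [0.0]

# Both versions compute the same primal value.
same_value = first_via_cache(cache, x) == first_forwarded(cache, x)

# Central finite difference of the forwarded version with respect to x[1].
h = 1e-6
fd = (first_forwarded(cache, [x[1] + h]) - first_forwarded(cache, [x[1] - h])) / (2h)
```

Here `fd` comes out close to `2 * x[1]`, i.e. about 6, matching the non-zero gradient reported for `FuncWithCacheFirst`. In the `sum(abs2, cache)` variant the load happens inside the reduction, so presumably no such direct forwarding takes place and the `Const` cache cuts the dependency.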