Incorrect forward-mode gradient for truncated normal distribution
MWE:
```julia-repl
julia> import FiniteDifferences, Enzyme; using Distributions

julia> func = (a, b, x) -> logpdf(truncated(Normal(), a, b), x)
#29 (generic function with 1 method)

julia> args = (-0.3, 0.3, 0.1)
(-0.3, 0.3, 0.1)

julia> Enzyme.gradient(Enzyme.Forward, Enzyme.Const(func), args...)
(0.6272640800199636, -2.607264080019964, -1.09)

julia> Enzyme.gradient(Enzyme.Reverse, Enzyme.Const(func), args...)
(1.6172640800199638, -1.617264080019964, -0.1)

julia> FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), func, args...)
(1.6172640800151292, -1.6172640800193545, -0.0999999999999462)
```
Tested on Enzyme v0.13.11.
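For a third, independent reference point, the gradient can also be written in closed form. With μ = 0 and σ = 1, `logpdf(truncated(Normal(), a, b), x) = log φ(x) - log(Φ(b) - Φ(a))` for `a < x < b`, so ∂/∂a = φ(a)/Z, ∂/∂b = -φ(b)/Z and ∂/∂x = -x, where Z = Φ(b) - Φ(a). The sketch below evaluates this formula. The helper names (`stdnormpdf`, `stdnormcdf`, `truncnormal_logpdf_grad`) are hypothetical illustration names, not Distributions.jl API, and Φ is built from `SpecialFunctions.erfc` rather than from Distributions.

```julia
using SpecialFunctions: erfc

# Standard normal pdf φ(z) and cdf Φ(z); Φ is written via erfc.
stdnormpdf(z) = exp(-z^2 / 2) / sqrt(2π)
stdnormcdf(z) = erfc(-z / sqrt(2)) / 2

# Hypothetical helper: closed-form gradient of
#   log φ(x) - log(Φ(b) - Φ(a))
# with respect to (a, b, x), valid for a < x < b (μ = 0, σ = 1).
function truncnormal_logpdf_grad(a, b, x)
    Z = stdnormcdf(b) - stdnormcdf(a)   # normalising constant
    da = stdnormpdf(a) / Z              # ∂/∂a of -log Z
    db = -stdnormpdf(b) / Z             # ∂/∂b of -log Z
    dx = -x                             # ∂/∂x of log φ(x)
    return (da, db, dx)
end

grad = truncnormal_logpdf_grad(-0.3, 0.3, 0.1)
println(grad)
```

The closed form should agree with the Reverse-mode and FiniteDifferences results, which suggests the Forward-mode result is the outlier.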
A further reduction of this would be helpful for resolving it.
I'd like to help, but unfortunately I'm too busy right now to minimise this further. It's up for grabs if anyone wants to do it.
@yebai @penelopeysm or @willtebbutt would you have cycles to minimize this?
@wsmoses Someone from the Enzyme team has to take this over since it no longer depends on Turing.
No one has experience with Distributions.jl.
If you want it to work, you're going to have to meet us halfway.
```julia
import Enzyme
import FiniteDifferences
using SpecialFunctions: erfc

function g(i, _not_used)
    k = sin(sin(sin(1 * erfc(0.25) / 1)))
    return (i, k)
end

function f(_not_used)
    i = (0.0, 3.9555)
    t = g(i, _not_used)
    return t[1][2]   # a constant: does not depend on the argument
end

args = (0.0,)
Enzyme.gradient(Enzyme.Forward, Enzyme.Const(f), args...)  # (3.9555,), wrong: f is constant, so this should be 0
FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, args...)  # (4.312107145607921e-14,), i.e. ≈ 0
```
This is the 'simplest' repro I can get, in that if you make any of the following further simplifications, Enzyme will give the right answer:
- Shrinking `i` to contain only one element. In fact, this bug only happens if `f` returns `t[1][n]` for `n > 1`.
- Changing `i[1]` to anything but 0.
  - `i[1]` actually corresponds to the mean of the normal distribution in the original example, so if you define `func = (a, b, x) -> logpdf(truncated(Normal(1.0, 1.0), a, b), x)` this bug doesn't happen.
- Unwrapping any of the tuples (either in the arguments to `g`, or its return value).
- Inlining `g` into `f`.
- Removing literally anything from the definition of `k`, even the multiplication and division by 1. (You can change the `sin`s into `cos` or any other trig function, but you can't remove any of them.)
Perhaps of interest, the gradient that Enzyme reports back is always the value of t[1][2] (illustrated here with 3.9555).
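Because `t[1][2]` is the literal 3.9555 and never depends on the argument, the correct forward-mode derivative of `f` is exactly 0, which is what FiniteDifferences approximates. The sketch below makes this concrete with a hand-rolled dual number (primal value plus tangent) pushed through a mirror of `f`. `Dual`, `constant`, `g_dual` and `f_tangent` are hypothetical illustration names, not Enzyme API, and this is only a toy model of forward mode, not how Enzyme works internally.

```julia
using SpecialFunctions: erfc

# Hypothetical minimal forward-mode dual number: a primal value and a tangent.
struct Dual
    val::Float64
    tan::Float64
end

# Values that do not depend on the differentiated argument carry a zero tangent.
constant(x::Real) = Dual(Float64(x), 0.0)

# Mirror of `g`: `k` is built only from constants, so it stays a plain Float64.
function g_dual(i, _not_used::Dual)
    k = sin(sin(sin(1 * erfc(0.25) / 1)))
    return (i, k)
end

# Mirror of `f`: seed the argument with tangent 1, then read off the
# tangent of the returned component t[1][2].
function f_tangent(x::Float64)
    xd = Dual(x, 1.0)                         # seed dx/dx = 1
    i = (constant(0.0), constant(3.9555))     # constants: zero tangents
    t = g_dual(i, xd)
    return t[1][2].tan
end

println(f_tangent(0.0))
```

Here the tangent of `t[1][2]` stays 0, while its primal value is 3.9555, the number Enzyme's forward mode reports as the gradient.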
Perfect, that should be enough!
Thanks @penelopeysm !!
This will be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/2052, pending a JLL bump.