Enzyme.jl

Incorrect forward mode gradient for truncated normal distribution

Open · mhauru opened this issue

MWE:

julia> import FiniteDifferences, Enzyme; using Distributions

julia> func = (a, b, x) -> logpdf(truncated(Normal(), a, b), x)
#29 (generic function with 1 method)

julia> args = (-0.3, 0.3, 0.1)
(-0.3, 0.3, 0.1)

julia> Enzyme.gradient(Enzyme.Forward, Enzyme.Const(func), args...)
(0.6272640800199636, -2.607264080019964, -1.09)

julia> Enzyme.gradient(Enzyme.Reverse, Enzyme.Const(func), args...)
(1.6172640800199638, -1.617264080019964, -0.1)

julia> FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), func, args...)
(1.6172640800151292, -1.6172640800193545, -0.0999999999999462)

On Enzyme v0.13.11.
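For reference, the correct values can be checked by hand. Inside the support (a < x < b), the standard definition gives logpdf = log φ(x) − log(Φ(b) − Φ(a)), so with Z = Φ(b) − Φ(a) the gradient with respect to (a, b, x) is (φ(a)/Z, −φ(b)/Z, −x). The sketch below evaluates this at the MWE's arguments using only Base Julia. The helper names `φ`, `simpson` and `analytic_grad` are hypothetical, and Simpson's rule stands in for a normal CDF function. It agrees with the reverse-mode and finite-difference results above, not with the forward-mode one.

```julia
# Standard normal density, written out so the sketch needs only Base Julia
φ(z) = exp(-z^2 / 2) / sqrt(2π)

# Composite Simpson's rule for the integral of f over [lo, hi] with n (even) subintervals;
# used here instead of a CDF function to compute Φ(hi) - Φ(lo)
function simpson(f, lo, hi; n = 1000)
    h = (hi - lo) / n
    s = f(lo) + f(hi)
    for k in 1:n-1
        s += (isodd(k) ? 4 : 2) * f(lo + k * h)
    end
    return s * h / 3
end

# Hand-derived gradient of (a, b, x) -> logpdf(truncated(Normal(), a, b), x),
# valid only for a < x < b, where logpdf = log φ(x) - log(Φ(b) - Φ(a))
function analytic_grad(a, b, x)
    Z = simpson(φ, a, b)  # normalising constant Φ(b) - Φ(a)
    return (φ(a) / Z, -φ(b) / Z, -x)
end

analytic_grad(-0.3, 0.3, 0.1)  # ≈ (1.61726, -1.61726, -0.1)
```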

mhauru, Oct 22 '24 14:10

A further reduction of this would help with resolving it.

wsmoses, Oct 22 '24 22:10

I'd like to help, but unfortunately I'm too busy right now to minimise it further. It's up for grabs for anyone who wants to do it.

mhauru, Oct 23 '24 16:10

@yebai @penelopeysm or @willtebbutt would you have cycles to minimize this?

wsmoses, Nov 03 '24 21:11

@wsmoses Someone from the Enzyme team has to take this over since it no longer depends on Turing.

yebai, Nov 03 '24 22:11

No one has experience with Distributions.jl.

If you want it to work, you're going to have to meet us halfway.

wsmoses, Nov 03 '24 22:11

import Enzyme
import FiniteDifferences
using SpecialFunctions: erfc

function g(i, _not_used)
    k = sin(sin(sin(1 * erfc(0.25) / 1)))
    return (i, k)
end

function f(_not_used)
    i = (0.0, 3.9555)
    t = g(i, _not_used)
    return t[1][2]
end

args = (0.0,)

# f does not depend on its argument, so the true gradient is (0.0,)
Enzyme.gradient(Enzyme.Forward, Enzyme.Const(f), args...)                 # (3.9555,)  wrong
FiniteDifferences.grad(FiniteDifferences.central_fdm(4, 1), f, args...)   # (4.312107145607921e-14,)  ≈ 0, correct

This is the 'simplest' repro I can get, in that if you make any of the following further simplifications, Enzyme will give the right answer:

  1. Shrinking i to contain only one element. In fact, this bug only happens if f returns t[1][n] for n > 1.
  2. Changing i[1] to anything but 0.
    • i[1] actually corresponds to the mean of the normal distribution in the original example, so if you define func = (a, b, x) -> logpdf(truncated(Normal(1.0, 1.0), a, b), x) this bug doesn't happen.
  3. Unwrapping any of the tuples (either in the arguments to g, or its return value).
  4. Inlining g into f.
  5. Removing literally anything from the definition of k, even the multiplication and division by 1. (You can change the sin to cos or any other trig function, but you can't remove any of them.)

Perhaps of interest: the gradient that Enzyme reports is always the value of t[1][2] (3.9555 in this example).

penelopeysm, Nov 04 '24 01:11

Perfect, that should be enough!

Thanks @penelopeysm !!

wsmoses, Nov 04 '24 02:11

This will be fixed by https://github.com/EnzymeAD/Enzyme.jl/pull/2052, pending a JLL bump.

wsmoses, Nov 04 '24 04:11