Enzyme's forward diff gives incorrect derivative of logsumexp when both arguments are zero

Open · alexandrebouchard opened this issue 2 months ago · 5 comments

Thanks for the awesome project!

I hit a corner case where Enzyme's forward diff gives an incorrect derivative of logsumexp(x, y). It occurs when both x and y are zero and the derivative is taken with respect to y. The problem does not occur with ForwardDiff.jl, only with Enzyme (I haven't tested reverse mode).

When y is very small, Enzyme and ForwardDiff are consistent, but when y is exactly zero their values differ; by a limiting argument and direct calculation, ForwardDiff gives the correct answer. Both the StatsFuns and LogExpFunctions implementations are affected by this bug. A naive log(exp(0.0) + exp(x)) implementation shows no bug.
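
For reference, the direct calculation: d/dy log(exp(x) + exp(y)) = exp(y) / (exp(x) + exp(y)), which is exactly 1/2 at x = y = 0.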

I was wondering if you have some guess at what is happening? Thanks in advance.

Here is an example:

using Enzyme, ForwardDiff, StatsFuns, LogExpFunctions

naive_logsumexp(x) = log(exp(0.0) + exp(x))
stats_fun_logsumexp(x) = StatsFuns.logsumexp(0.0, x)
stats_fun_logaddexp(x) = StatsFuns.logaddexp(0.0, x)
logexpfunctions_logsumexp(x) = LogExpFunctions.logsumexp(0.0, x)

function test_ads(fct, point)
    @show fct, point
    # compare Enzyme's forward-mode derivative against ForwardDiff's
    enz = autodiff(Enzyme.Forward, fct, Duplicated(point, 1.0))[1]
    fd = ForwardDiff.derivative(fct, point)
    @show enz, fd, enz ≈ fd
end

function minimum_reproducible_ad_bug()
    for value in [0.0, 0.00001]
        test_ads(naive_logsumexp, value)
        test_ads(stats_fun_logsumexp, value)
        test_ads(stats_fun_logaddexp, value)
        test_ads(logexpfunctions_logsumexp, value)
    end
end

gives

julia> minimum_reproducible_ad_bug()
(fct, point) = (Main.naive_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.5, 0.5, true)
(fct, point) = (Main.stats_fun_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.stats_fun_logaddexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.logexpfunctions_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.naive_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000024999999999, 0.5000024999999999, true)
(fct, point) = (Main.stats_fun_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
(fct, point) = (Main.stats_fun_logaddexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
(fct, point) = (Main.logexpfunctions_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)

alexandrebouchard · Oct 22 '25

Thinking about it, the culprit must be a branch handling special cases in logsumexp implementations. E.g., https://github.com/JuliaStats/LogExpFunctions.jl/blob/77c5cf030b58b14f118f237b6b518005499a7f40/src/logsumexp.jl#L163

I suppose ForwardDiff must have custom rules to prevent that issue.
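
To illustrate the hypothesis, here is a minimal sketch (my own simplified version, not the actual LogExpFunctions code). When x == y, the special-case branch returns an expression that does not mention y at all, so forward AD propagates a zero tangent even though the true partial is 1/2:

using Enzyme

# Hypothetical branchy implementation: the x == y special case returns
# an expression that is constant in y.
branchy_logaddexp(x, y) = x == y ? x + log(2) : max(x, y) + log1p(exp(-abs(x - y)))

# Differentiating the branch actually taken at x == y == 0 gives 0.0,
# although d/dy log(exp(0) + exp(y)) at y = 0 is 0.5.
autodiff(Enzyme.Forward, y -> branchy_logaddexp(0.0, y), Duplicated(0.0, 1.0))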

alexandrebouchard · Oct 22 '25

Based on this hypothesis, I suppose the solution would be to write custom rules in these two projects. However, I am curious whether there would be a way to avoid a silent failure here and instead have the forward AD print out an error. Presumably this kind of situation would occur when the dispatch of == involves a dual number, which is under Enzyme's control.
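
For what it's worth, here is why the failure is silent in a dual-number AD like ForwardDiff (a sketch; Enzyme works at the LLVM level rather than with dual numbers, but the comparison there likewise sees only the primal value):

using ForwardDiff: Dual

d = Dual(0.0, 1.0)  # primal 0.0, tangent 1.0
d == 0.0            # true: == compares only the primal value, so the
                    # x == y special-case branch fires and drops the tangent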

Also, FYI, it seems the vector version of logsumexp is also affected; in that case, ForwardDiff also silently gives an erroneous answer:

naive_logsumexp(x) = log(exp(0.0) + exp(x))
stats_fun_logsumexp(x) = StatsFuns.logsumexp(0.0, x)
stats_fun_logsumexp2(x) = StatsFuns.logsumexp([0.0, x])
stats_fun_logsumexp3(x) = StatsFuns.logsumexp([x, 0.0])
stats_fun_logaddexp(x) = StatsFuns.logaddexp(0.0, x)
logexpfunctions_logsumexp(x) = LogExpFunctions.logsumexp(0.0, x)

function test_ads(fct, point)
    @show fct, point
    # compare Enzyme's forward-mode derivative against ForwardDiff's
    enz = autodiff(Enzyme.Forward, fct, Duplicated(point, 1.0))[1]
    fd = ForwardDiff.derivative(fct, point)
    @show enz, fd, enz ≈ fd
end

function minimum_reproducible_ad_bug()
    for value in [0.0, 0.00001]
        for f in [naive_logsumexp, naive_logsumexp, stats_fun_logsumexp2, stats_fun_logsumexp3, stats_fun_logaddexp, logexpfunctions_logsumexp]
            test_ads(f, value)
        end
    end
end
gives

julia> minimum_reproducible_ad_bug()
(fct, point) = (Main.naive_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.5, 0.5, true)
(fct, point) = (Main.naive_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.5, 0.5, true)
(fct, point) = (Main.stats_fun_logsumexp2, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.0, true)
(fct, point) = (Main.stats_fun_logsumexp3, 0.0)
(enz, fd, enz ≈ fd) = (1.0, 1.0, true)
(fct, point) = (Main.stats_fun_logaddexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.logexpfunctions_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.naive_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000024999999999, 0.5000024999999999, true)
(fct, point) = (Main.naive_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000024999999999, 0.5000024999999999, true)
(fct, point) = (Main.stats_fun_logsumexp2, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025000000001, true)
(fct, point) = (Main.stats_fun_logsumexp3, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025000000001, true)
(fct, point) = (Main.stats_fun_logaddexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
(fct, point) = (Main.logexpfunctions_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)

alexandrebouchard · Oct 22 '25

StatsFuns.jl does have a ChainRules extension, but I see nothing for logsumexp: https://github.com/JuliaStats/StatsFuns.jl/blob/master/ext/StatsFunsChainRulesCoreExt.jl

vchuravy · Oct 22 '25

Ah it might be this: https://github.com/JuliaStats/LogExpFunctions.jl/blob/77c5cf030b58b14f118f237b6b518005499a7f40/ext/LogExpFunctionsChainRulesCoreExt.jl#L136
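
For reference, a custom forward rule that sidesteps the branch would look roughly like this (an illustrative sketch using the closed-form partials, not the code actually shipped in that extension):

using ChainRulesCore, LogExpFunctions

function ChainRulesCore.frule((_, Δx, Δy), ::typeof(LogExpFunctions.logaddexp), x::Real, y::Real)
    Ω = LogExpFunctions.logaddexp(x, y)
    # ∂Ω/∂x = exp(x - Ω) and ∂Ω/∂y = exp(y - Ω); both equal 1/2 at x == y == 0,
    # so the tangent stays correct even where the primal takes a special-case branch.
    return Ω, exp(x - Ω) * Δx + exp(y - Ω) * Δy
end

As far as I know, Enzyme does not pick up ChainRules rules automatically, so an Enzyme-side fix would go through EnzymeRules or Enzyme's ChainRules import machinery instead.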

vchuravy · Oct 22 '25

> However I am curious if there would be a way to avoid silent failure here and instead have the forward AD print out an error.

This is tricky: Enzyme handles == just fine, but calculates the sub-gradient that matches your function definition. As an example, the abs function has three possible gradient values at 0.0, and Enzyme chooses the one that matches your implementation.
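
A minimal sketch of that behavior (expected results shown as comments, under "differentiate the branch actually taken" semantics):

using Enzyme

# Three value-equivalent implementations of |x| that embed different
# choices at x == 0:
abs_a(x) = x < 0 ? -x : x                        # takes the  x branch at 0
abs_b(x) = x <= 0 ? -x : x                       # takes the -x branch at 0
abs_c(x) = x == 0 ? zero(x) : (x < 0 ? -x : x)   # takes a constant branch at 0

autodiff(Enzyme.Forward, abs_a, Duplicated(0.0, 1.0))  # expect (1.0,)
autodiff(Enzyme.Forward, abs_b, Duplicated(0.0, 1.0))  # expect (-1.0,)
autodiff(Enzyme.Forward, abs_c, Duplicated(0.0, 1.0))  # expect (0.0,)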

Also see this article for other cases of AD pitfalls: https://wires.onlinelibrary.wiley.com/doi/full/10.1002/widm.1555#widm1555-sec-0003-title

vchuravy · Oct 22 '25