Enzyme's forward diff gives incorrect derivative of logsumexp when both arguments are zero
Thanks for the awesome project!
I hit a corner case where Enzyme's forward diff gives an incorrect derivative of logsumexp(x, y). It occurs when both x and y are zero and we take the derivative with respect to y. The problem does not occur with ForwardDiff.jl, only with Enzyme (I haven't tested reverse mode).
When y is very small, Enzyme and ForwardDiff agree. But when y is exactly zero, the values differ between Enzyme and ForwardDiff.jl; by a limiting argument and direct calculation, ForwardDiff gives the correct answer. Both the StatsFuns and LogExpFunctions implementations are affected by this bug. There is no bug with a naive log(exp(0.0) + exp(x)) implementation.
I was wondering if you have a guess as to what is happening? Thanks in advance.
Here is an example:
using Enzyme, ForwardDiff, StatsFuns, LogExpFunctions
naive_logsumexp(x) = log(exp(0.0) + exp(x))
stats_fun_logsumexp(x) = StatsFuns.logsumexp(0.0, x)
stats_fun_logaddexp(x) = StatsFuns.logaddexp(0.0, x)
logexpfunctions_logsumexp(x) = LogExpFunctions.logsumexp(0.0, x)
function test_ads(fct, point)
    @show fct, point
    enz = autodiff(Enzyme.Forward, fct, Duplicated(point, 1.0))[1]
    fd = ForwardDiff.derivative(fct, point)
    @show enz, fd, enz ≈ fd
end

function minimum_reproducible_ad_bug()
    for value in [0.0, 0.00001]
        test_ads(naive_logsumexp, value)
        test_ads(stats_fun_logsumexp, value)
        test_ads(stats_fun_logaddexp, value)
        test_ads(logexpfunctions_logsumexp, value)
    end
end
gives
julia> minimum_reproducible_ad_bug()
(fct, point) = (Main.naive_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.5, 0.5, true)
(fct, point) = (Main.stats_fun_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.stats_fun_logaddexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.logexpfunctions_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.naive_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000024999999999, 0.5000024999999999, true)
(fct, point) = (Main.stats_fun_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
(fct, point) = (Main.stats_fun_logaddexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
(fct, point) = (Main.logexpfunctions_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
Thinking about it, the culprit must be a branch handling special cases in the logsumexp implementations, e.g. https://github.com/JuliaStats/LogExpFunctions.jl/blob/77c5cf030b58b14f118f237b6b518005499a7f40/src/logsumexp.jl#L163
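To make the hypothesis concrete, here is a minimal sketch of the pattern I suspect (a paraphrase of what the linked line does, not a verbatim copy of the library code). When x == y, the difference is replaced by a literal zero, so the log1p(exp(-Δ)) term contributes no derivative at all, and the result rides entirely on whichever argument max(x, y) happens to propagate at the tie:

# hypothetical paraphrase of the suspected branch, not the actual library code
function branchy_logaddexp(x, y)
    Δ = ifelse(x == y, zero(x - y), abs(x - y))  # x == y ⇒ Δ is a constant zero
    return max(x, y) + log1p(exp(-Δ))            # derivative now flows only through max
end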
I suppose ForwardDiff must have custom rules that prevent this issue.
Based on this hypothesis, I suppose the solution would be to write custom rules in these two projects. However, I am curious whether there would be a way to avoid a silent failure here and instead have the forward AD raise an error. Presumably this kind of situation arises when the dispatch of == involves a dual number, which is under Enzyme's control.
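For concreteness, such a rule could look something like this sketch (assuming a ChainRulesCore frule is the right mechanism; this is my own sketch, not the actual library code, but the closed-form partials are well-defined even when x == y, so there is no branch to fall into):

using ChainRulesCore, LogExpFunctions

# With z = log(exp(x) + exp(y)), the partials are
# ∂z/∂x = exp(x - z) and ∂z/∂y = exp(y - z); both equal 0.5 at x == y == 0.
function ChainRulesCore.frule((_, Δx, Δy), ::typeof(LogExpFunctions.logaddexp), x::Real, y::Real)
    z = LogExpFunctions.logaddexp(x, y)
    return z, exp(x - z) * Δx + exp(y - z) * Δy
end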
Also, FYI, it seems the vector version of logsumexp is also affected; in that case, ForwardDiff also silently gives an erroneous answer:
naive_logsumexp(x) = log(exp(0.0) + exp(x))
stats_fun_logsumexp(x) = StatsFuns.logsumexp(0.0, x)
stats_fun_logsumexp2(x) = StatsFuns.logsumexp([0.0, x])
stats_fun_logsumexp3(x) = StatsFuns.logsumexp([x, 0.0])
stats_fun_logaddexp(x) = StatsFuns.logaddexp(0.0, x)
logexpfunctions_logsumexp(x) = LogExpFunctions.logsumexp(0.0, x)
function test_ads(fct, point)
    @show fct, point
    enz = autodiff(Enzyme.Forward, fct, Duplicated(point, 1.0))[1]
    fd = ForwardDiff.derivative(fct, point)
    @show enz, fd, enz ≈ fd
end

function minimum_reproducible_ad_bug()
    for value in [0.0, 0.00001]
        for f in [naive_logsumexp, naive_logsumexp, stats_fun_logsumexp2, stats_fun_logsumexp3, stats_fun_logaddexp, logexpfunctions_logsumexp]
            test_ads(f, value)
        end
    end
end
julia> minimum_reproducible_ad_bug()
(fct, point) = (Main.naive_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.5, 0.5, true)
(fct, point) = (Main.naive_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.5, 0.5, true)
(fct, point) = (Main.stats_fun_logsumexp2, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.0, true)
(fct, point) = (Main.stats_fun_logsumexp3, 0.0)
(enz, fd, enz ≈ fd) = (1.0, 1.0, true)
(fct, point) = (Main.stats_fun_logaddexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.logexpfunctions_logsumexp, 0.0)
(enz, fd, enz ≈ fd) = (0.0, 0.5, false)
(fct, point) = (Main.naive_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000024999999999, 0.5000024999999999, true)
(fct, point) = (Main.naive_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000024999999999, 0.5000024999999999, true)
(fct, point) = (Main.stats_fun_logsumexp2, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025000000001, true)
(fct, point) = (Main.stats_fun_logsumexp3, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025000000001, true)
(fct, point) = (Main.stats_fun_logaddexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
(fct, point) = (Main.logexpfunctions_logsumexp, 1.0e-5)
(enz, fd, enz ≈ fd) = (0.5000025, 0.5000025, true)
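As an aside, the vector case does not seem to require a custom rule at all if the implementation is arranged so that the max cancels. In this sketch (my own reference implementation, not the StatsFuns code), u appears both outside the log and inside every exp, so whichever subgradient the AD assigns to maximum at a tie cancels exactly, and the derivative comes out as the softmax weights (0.5 here) regardless of argument order:

# my own reference sketch, not the StatsFuns implementation
function reference_logsumexp(x::AbstractVector)
    u = maximum(x)
    # the du terms cancel: dz = Σ_i softmax(x)_i * dx_i, well-defined at ties
    return u + log(sum(xi -> exp(xi - u), x))
end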
StatsFuns.jl does have a ChainRules extension, but I see nothing for logsumexp: https://github.com/JuliaStats/StatsFuns.jl/blob/master/ext/StatsFunsChainRulesCoreExt.jl
Ah, it might be this: https://github.com/JuliaStats/LogExpFunctions.jl/blob/77c5cf030b58b14f118f237b6b518005499a7f40/ext/LogExpFunctionsChainRulesCoreExt.jl#L136
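One caveat I am not sure about (an untested assumption on my part): Enzyme does not consume ChainRules frules automatically, so even once such a rule exists it may need to be imported explicitly through Enzyme's ChainRules interop, along the lines of:

using Enzyme, LogExpFunctions

# hypothetical usage; the argument types here are illustrative
Enzyme.@import_frule(typeof(LogExpFunctions.logaddexp), Float64, Float64)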
> However, I am curious whether there would be a way to avoid a silent failure here and instead have the forward AD raise an error.
This is tricky: Enzyme handles == just fine, but calculates the subgradient that matches your function definition. As an example, the abs function has three possible gradient values at 0.0, and Enzyme chooses the one that matches your implementation.
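To illustrate with a minimal sketch (hypothetical function names, not code from this thread): two mathematically identical definitions of abs whose branches select different subgradients at 0.0, where forward-mode Enzyme returns the derivative of whichever branch is actually taken:

using Enzyme

abs_left(x)  = x < 0 ? -x : x   # at 0.0 the `x` branch runs  → derivative 1.0
abs_right(x) = x > 0 ? x : -x   # at 0.0 the `-x` branch runs → derivative -1.0

autodiff(Enzyme.Forward, abs_left,  Duplicated(0.0, 1.0))  # (1.0,)
autodiff(Enzyme.Forward, abs_right, Duplicated(0.0, 1.0))  # (-1.0,)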
For other cases of AD pitfalls, also see https://wires.onlinelibrary.wiley.com/doi/full/10.1002/widm.1555#widm1555-sec-0003-title