
Wrong results for forward-mode `exp!` half of the time

Open baggepinnen opened this issue 2 years ago • 5 comments

If I repeatedly run the example below, I get the wrong result for the gradient through `exp!` about half of the time.

using LinearAlgebra, ForwardDiff, FiniteDiff, ForwardDiffChainRules

@ForwardDiff_frule LinearAlgebra.exp!(x1::AbstractMatrix{<:ForwardDiff.Dual})

function test_exp(x)
    X = copy(reshape(x, 4, 4))
    X2 = LinearAlgebra.exp!(X)
    sum(X2)
end

for i = 1:20
    x = randn(16)
    X = reshape(x, 4, 4)
    g1 = ForwardDiff.gradient(test_exp, x)
    g2 = FiniteDiff.finite_difference_gradient(test_exp, x)
    @show norm(g1-g2)
end
norm(g1 - g2) = 3.2745988814567806
norm(g1 - g2) = 2.7005934051461515e-9
norm(g1 - g2) = 3.535502368190921
norm(g1 - g2) = 5.376574873194121e-10
norm(g1 - g2) = 2.4158271822718778e-9
norm(g1 - g2) = 3.0885755390647527e-10
norm(g1 - g2) = 4.215282668056846
norm(g1 - g2) = 1.7888448238515218
norm(g1 - g2) = 2.1068558951714456e-10
norm(g1 - g2) = 8.090031857043094
norm(g1 - g2) = 5.8514613833452644
norm(g1 - g2) = 5.859275463330073e-10
norm(g1 - g2) = 3.3486620002856527e-10
norm(g1 - g2) = 1.1628716126438234
norm(g1 - g2) = 2.72443511328846
norm(g1 - g2) = 1.5771975088961793e-10
norm(g1 - g2) = 1.2073055237629486e-9
norm(g1 - g2) = 8.255800634241801
norm(g1 - g2) = 2.2459662479919337e-10
norm(g1 - g2) = 4.433466845638335

Also reported in https://github.com/ThummeTo/ForwardDiffChainRules.jl/issues/14

baggepinnen · May 17 '23 08:05

This only appears to be a problem for the non-symmetric case of `exp!`. When I create a symmetric matrix (`X = X'X`) and switch to FiniteDifferences.jl for more accurate testing, it works just fine; the non-symmetric input matrix is still problematic, though.

using LinearAlgebra, ForwardDiff, ForwardDiffChainRules, FiniteDifferences
@ForwardDiff_frule LinearAlgebra.exp!(x1::AbstractMatrix{<:ForwardDiff.Dual})
function test_exp(x)
    X = copy(reshape(x, 4, 4))
    X2 = LinearAlgebra.exp!(X)
    sum(X2)
end

for i = 1:20
    X = randn(4,4)
    X = X'X
    x = vec(X)
    g1 = ForwardDiff.gradient(test_exp, x)
    g2 = FiniteDifferences.grad(central_fdm(5, 1), test_exp, x)[1]
    @show norm(g1-g2)
end

I've also tested the reverse rule using Zygote, and there is no problem in reverse mode.
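
For the record, roughly the kind of check I ran (a minimal sketch, not the exact script; it uses the non-mutating `exp`, and the name `test_exp_rev` is just for illustration):

using LinearAlgebra, Zygote, FiniteDifferences

# Same scalar-valued test as above, but without in-place mutation.
function test_exp_rev(x)
    X = reshape(x, 4, 4)
    sum(exp(X))
end

x = randn(16)
g1 = Zygote.gradient(test_exp_rev, x)[1]
g2 = FiniteDifferences.grad(central_fdm(5, 1), test_exp_rev, x)[1]
norm(g1 - g2)  # consistently small in reverse mode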

baggepinnen · May 17 '23 08:05

@sethaxen do you think you might have time to look into this?

oxinabox · May 17 '23 11:05

Yes, I can look into this.

@baggepinnen I'm not familiar with ForwardDiffChainRules. Can you provide an MWE that exhibits the observed failure using just ChainRules?

Using our own testing machinery, I am unable to observe any failures on 1000× as many random matrices:

julia> using ChainRules, ChainRulesTestUtils, LinearAlgebra, Random, Test

julia> Random.seed!(42);

julia> @testset "exp!" begin
           Xs = (randn(4, 4) for _ in 1:20_000)
           @testset for X in Xs
               test_frule(LinearAlgebra.exp!, X)
           end
       end;
Test Summary: |  Pass  Total   Time
exp!          | 80000  80000  13.2s

Btw, in a fresh environment, your example errors on my machine in the for loop with:

ERROR: MethodError: no method matching iterate(::Nothing)

Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen})
   @ Base range.jl:880
  iterate(::Union{LinRange, StepRangeLen}, ::Integer)
   @ Base range.jl:880
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}}
   @ Base dict.jl:698
  ...

Stacktrace:
 [1] indexed_iterate(I::Nothing, i::Int64)
   @ Base ./tuple.jl:91
 [2] exp!(x1::Matrix{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
   @ Main ~/.julia/packages/ForwardDiffChainRules/2Xt9G/src/ForwardDiffChainRules.jl:81
 [3] test_exp(x::Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}})
   @ Main ./REPL[4]:3
 [4] chunk_mode_gradient(f::typeof(test_exp), x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:123
 [5] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}}, ::Val{true})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:21
 [6] gradient(f::Function, x::Vector{Float64}, cfg::ForwardDiff.GradientConfig{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8, Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(test_exp), Float64}, Float64, 8}}})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [7] gradient(f::Function, x::Vector{Float64})
   @ ForwardDiff ~/.julia/packages/ForwardDiff/vXysl/src/gradient.jl:17
 [8] top-level scope
   @ ./REPL[5]:4

sethaxen · May 17 '23 12:05

I think the problem is related to how ForwardDiffChainRules deals with (or rather, doesn't deal with) the fact that `exp!` mutates its input argument. By adding a call to `copy` on the input argument before each invocation of the `frule`, I get the correct results. This is probably an issue with ForwardDiffChainRules, then.
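
At the `frule` level, the workaround looks roughly like this (a minimal sketch; the point is that copying keeps the original primal intact across calls):

using ChainRules, ChainRulesCore, LinearAlgebra

X = randn(4, 4)
ΔX = randn(4, 4)

# Copying before each call means exp! never overwrites X, so both
# calls differentiate exp at the same point.
Y1, Ẏ1 = frule((NoTangent(), ΔX), LinearAlgebra.exp!, copy(X))
Y2, Ẏ2 = frule((NoTangent(), ΔX), LinearAlgebra.exp!, copy(X))
Y1 ≈ Y2  # true: both calls saw the same input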

baggepinnen · May 17 '23 13:05

Sounds right. This line calls `frule` repeatedly on the same primals, so it assumes the function is non-mutating: https://github.com/ThummeTo/ForwardDiffChainRules.jl/blob/d70301a28f61250c3168446c4b147b195ceee117/src/ForwardDiffChainRules.jl#L88
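
To illustrate the failure mode (a minimal sketch): without a copy, the second call receives the already-mutated primal, so it differentiates exp at exp(X) instead of at X:

using ChainRules, ChainRulesCore, LinearAlgebra

X = randn(4, 4)
ΔX = randn(4, 4)

# exp! overwrites X in place, so copy the returned primal to compare.
Y1 = copy(first(frule((NoTangent(), ΔX), LinearAlgebra.exp!, X)))  # X is now exp(X)
Y2 = copy(first(frule((NoTangent(), ΔX), LinearAlgebra.exp!, X)))  # computed at exp(X)
Y1 ≈ Y2  # false: the primal drifted between the two calls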

sethaxen · May 17 '23 16:05