
`NaN`s in jacobian of a function that uses a `StaticArray` and `norm`

Open ferrolho opened this issue 4 years ago • 11 comments

Hi! I am having trouble differentiating a function with ForwardDiff.jl. I don't know why, but the resulting jacobian contains NaNs. Below is an MWE (minimal working example) to kickstart the discussion.

Consider the following functions:

using BenchmarkTools, ForwardDiff, LinearAlgebra, StaticArrays

foo2(A) = SVector{4}(norm(r, 2) for r = eachrow(A))

function foo1!(out, x; μ = 0.8 / √2)
    λ = SVector{3}(@view x[1:3])
    K = SMatrix{3,3}(@view x[4:12])

    A_λ = @SMatrix [ 1   0  -1   0 ;
                     0   1   0  -1 ;
                     μ   μ   μ   μ ]

    out[1:4] = A_λ' * λ + foo2(A_λ' * K)

    out
end

Let's try out foo1!:

julia> x = rand(12);

julia> out = zeros(4);

julia> foo1!(out, x)
4-element Vector{Float64}:
 2.403435298832313
 2.0030472808350077
 0.7216049166673891
 0.6293600802591698

foo1! is type-stable and does not perform dynamic allocations:

julia> @btime $foo1!($out, $x)
  21.684 ns (0 allocations: 0 bytes)
4-element Vector{Float64}:
 2.403435298832313
 2.0030472808350077
 0.7216049166673891
 0.6293600802591698

We can use ForwardDiff.jl to compute the jacobian of foo1!:

julia> ForwardDiff.jacobian(foo1!, out, x)
4×12 Matrix{Float64}:
  1.0   0.0  0.565685  0.606882  0.0        0.343304  …   0.353444  0.491234  0.0        0.277884
  0.0   1.0  0.565685  0.0       0.744664   0.421245      0.22503   0.0       0.535939   0.303173
 -1.0   0.0  0.565685  0.722007  0.0       -0.408429     -0.362283  0.261825  0.0       -0.14811
  0.0  -1.0  0.565685  0.0       0.930926  -0.526611     -0.154069  0.0       0.243306  -0.137634

However, if the inputs are all zeros, the jacobian will contain NaNs:

julia> x = zeros(12);

julia> ForwardDiff.jacobian(foo1!, out, x)
4×12 Matrix{Float64}:
 NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
 NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
 NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN
 NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN  NaN

But if we change the type of the matrix A_λ (in foo1!) from an SMatrix to a normal Matrix, the jacobian will be evaluated properly:

julia> function foo1!(out, x; μ = 0.8 / √2)
           λ = SVector{3}(@view x[1:3])
           K = SMatrix{3,3}(@view x[4:12])

           A_λ = [ 1   0  -1   0 ;
                   0   1   0  -1 ;
                   μ   μ   μ   μ ]

           out[1:4] = A_λ' * λ + foo2(A_λ' * K)

           out
       end
foo1! (generic function with 1 method)

julia> ForwardDiff.jacobian(foo1!, out, x)
4×12 Matrix{Float64}:
  1.0   0.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0   1.0   0.0  0.565685
  0.0   1.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0   0.0   1.0  0.565685
 -1.0   0.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0  -1.0   0.0  0.565685
  0.0  -1.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0   0.0  -1.0  0.565685

I have also observed that the jacobian will not contain NaNs if we use the 1-norm or the Inf-norm, even if we keep A_λ as an SMatrix, i.e.,

foo2(A) = SVector{4}(norm(r, 1) for r = eachrow(A))

or

foo2(A) = SVector{4}(norm(r, Inf) for r = eachrow(A))

I am not sure if this is a bug or if I am doing something wrong... Can someone help me figure it out? Thank you in advance!

ferrolho avatar Oct 02 '21 18:10 ferrolho

I have read the docs section Fixing NaN/Inf Issues under User Documentation > Advanced Usage Guide. Indeed, if I set the NANSAFE_MODE_ENABLED constant to true, the jacobian no longer contains NaNs. However, the evaluated jacobian is different compared to the version of foo1! which defines A_λ as a normal Matrix.

If A_λ is a Matrix{Float64}:

julia> ForwardDiff.jacobian(foo1!, out, x)
4×12 Matrix{Float64}:
  1.0   0.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0   1.0   0.0  0.565685
  0.0   1.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0   0.0   1.0  0.565685
 -1.0   0.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0  -1.0   0.0  0.565685
  0.0  -1.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0   0.0  -1.0  0.565685

If A_λ is an SMatrix{3, 4, Float64, 12}:

julia> ForwardDiff.jacobian(foo1!, out, x)
4×12 Matrix{Float64}:
  1.0   0.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
  0.0   1.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 -1.0   0.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
  0.0  -1.0  0.565685  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

I would appreciate some guidance on what I should do. I.e., should I define NANSAFE_MODE_ENABLED as true from now on (and take the ~5%-10% performance hit)? Or is there something else I can do to fix this issue, without having to toggle NANSAFE_MODE_ENABLED? Thank you!

ferrolho avatar Oct 02 '21 19:10 ferrolho

Regardless of NaNs, it would be good to figure out why Matrix and SMatrix differ. I guess just peppering some print statements throughout the code should reveal where things diverge.

KristofferC avatar Oct 02 '21 19:10 KristofferC

Hi, @KristofferC. Thank you for your reply. You mean in ForwardDiff.jl? Where should I start? :sweat_smile:

ferrolho avatar Oct 02 '21 19:10 ferrolho

With StaticArrays 0.12.4 I get the same result between static matrix and matrix running the example in the first post. With StaticArrays 1.2.13 I get your result. So something happened there.

Bisecting points to https://github.com/JuliaArrays/StaticArrays.jl/pull/908 as the offending PR.

KristofferC avatar Oct 02 '21 20:10 KristofferC

A more minimal example is

julia> using StaticArrays, ForwardDiff

julia> ForwardDiff.gradient(norm, [0,0])
2-element Vector{Float64}:
 0.0
 1.0

julia> ForwardDiff.gradient(norm, SA[0,0])
2-element SVector{2, Float64} with indices SOneTo(2):
 NaN
 NaN

With a hand-written norm, we end up evaluating sqrt at 0, where the derivative is infinite. The non-NaN answer is something like a choice of subgradient:

julia> ForwardDiff.gradient(x -> sqrt(sum(abs2, x)), [0,0])
2-element Vector{Float64}:
 NaN
 NaN

julia> ForwardDiff.gradient(x -> sqrt(sum(x .^ 2)), [0, eps()])
2-element Vector{Float64}:
 0.0
 1.0

julia> ForwardDiff.gradient(x -> sqrt(sum(x .^ 2)), [eps(), 0])
2-element Vector{Float64}:
 1.0
 0.0

Note also that, because it's caused by an explicit branch (at least according to https://github.com/JuliaArrays/StaticArrays.jl/pull/964; I didn't check LinearAlgebra's source), it will be changed by https://github.com/JuliaDiff/ForwardDiff.jl/pull/481 to never take the measure-zero branch, and hence to always give NaNs.

Zygote's behaviour (using a rule for norm from ChainRules) does not pick one of these, BTW:

julia> Zygote.gradient(x -> sqrt(sum(abs2, x)), [0,0])
([NaN, NaN],)

julia> Zygote.gradient(norm, [0,0])
([0.0, 0.0],)

That rule could choose to pick a particular "sub-gradient". But I'm not sure that doing the same with DiffRules would work, since the function takes a vector, not a scalar.

mcabbott avatar Oct 07 '21 19:10 mcabbott

It seems this particular issue is due to the derivative for sqrt assuming the input is nonzero.

julia> ForwardDiff.gradient(x -> sqrt(sum(abs2, x)), [0,0])  # bad
2-element Vector{Float64}:
 NaN
 NaN

julia> ForwardDiff.gradient(x -> sum(abs2, x), [0,0]) # removing the sqrt, fine
2-element Vector{Int64}:
 0
 0

julia> sqrt(ForwardDiff.Dual(0.0, 0.0)) # Uh-oh! should be zero!
Dual{Nothing}(0.0,NaN)

julia> ForwardDiff.Dual(0.0, 0.0)^0.5 # yes, pow has the same issue
Dual{Nothing}(0.0,NaN)

I think the right fix here is to bypass DiffRules' rules for ^, sqrt, and cbrt so that they return 0 in these cases, since a DiffRules rule can't take the value of the input's partial derivative into account.
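To make the 0 * Inf mechanism concrete, here is a minimal sketch with a toy dual number (TinyDual and the function names are mine; this is not ForwardDiff's implementation):

```julia
# Toy forward-mode dual number, only to illustrate the mechanism.
struct TinyDual
    val::Float64
    der::Float64
end

# Naive chain rule: d/dx sqrt(x) = 1 / (2 * sqrt(x)). At val == 0 this
# computes 0.0 / 0.0 == NaN, even though the incoming partial is zero.
naive_sqrt(d::TinyDual) = TinyDual(sqrt(d.val), d.der / (2 * sqrt(d.val)))

# Proposed behaviour: if the incoming partial is exactly zero, return a
# zero partial instead of multiplying it by an infinite derivative.
safe_sqrt(d::TinyDual) =
    TinyDual(sqrt(d.val), iszero(d.der) ? zero(d.der) : d.der / (2 * sqrt(d.val)))
```

With this, naive_sqrt(TinyDual(0.0, 0.0)) has a NaN partial, while safe_sqrt(TinyDual(0.0, 0.0)) has a zero partial, and the two agree away from zero.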

See also https://github.com/JuliaDiff/ChainRules.jl/issues/576

sethaxen avatar Jan 19 '22 13:01 sethaxen

Oh right, well spotted. The dual part is indeed zero:

julia> sqrt_show(x) = sqrt(@show x);

julia> ForwardDiff.gradient(x -> sqrt_show(sum(abs2, x)), [0,0])
x = Dual{ForwardDiff.Tag{var"#5#6", Int64}}(0,0,0)
2-element Vector{Float64}:
 NaN
 NaN

Catching that and replacing it by zero can avoid the NaN, but seems slightly weird in that it doesn't pick a subgradient (if that's the term). While [0, eps()] gives you a vector of length 1, telling you that moving away from zero will change the function, this would imply that the function is flat there.

We can picture this (for a 2-vector) as the tip of a cone, like a cone rolled up from a piece of paper. The current behaviour is to say that the tip is a singularity, hence has no gradient. The sqrt proposal would smooth the tip off, so that microscopic angels can treat it as a dance floor. If you could always pick [0, eps()], then instead you are always informed that you can go downhill, and a random direction is provided, but the rate at which you will go downhill is correct.

Maybe that's not so strange then. It's a bit like derivative(abs, 0) == 0, rounding off the singularity. While [0, eps()] is like derivative(abs, 0) == 1. At present CR picks the former, ForwardDiff the latter.
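For code like the example in the first post, one user-side workaround in this spirit is to pick the zero subgradient yourself by branching before the sqrt (the helper safenorm2 is hypothetical, not part of any package):

```julia
# Hypothetical helper (not part of ForwardDiff, StaticArrays, or LinearAlgebra):
# a 2-norm that branches before the sqrt, so that at the origin it returns an
# exact zero with zero partials, i.e. it picks the zero subgradient.
safenorm2(r) = (s = sum(abs2, r); iszero(s) ? zero(s) : sqrt(s))
```

Using it in foo2, e.g. foo2(A) = SVector{4}(safenorm2(r) for r = eachrow(A)), avoids the 0/0 inside the sqrt rule, at the cost of declaring the function flat at the origin. (For floating-point and dual inputs both branches return the same type.)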

mcabbott avatar Jan 19 '22 13:01 mcabbott

FWIW, finite differences finds that the gradient of norm at a zero vector is zero:

julia> using FiniteDifferences, LinearAlgebra

julia> FiniteDifferences.grad(central_fdm(5, 1), norm, [0.0,0.0])
([-3.8050255348364236e-17, -3.8050255348364236e-17],)

So I would be inclined to side with CR here.

But I think either way, the sqrt, cbrt, ^ issue is bad behavior that needs to be addressed.

sethaxen avatar Jan 19 '22 13:01 sethaxen

Finite differencing just averages the two sides of abs, right? If it were doing anything smarter I'd give it less weight...

mcabbott avatar Jan 19 '22 14:01 mcabbott

As a side remark: as noted above, everything works if one uses the NaN-safe setting (https://juliadiff.org/ForwardDiff.jl/dev/user/advanced/#Fixing-NaN/Inf-Issues). I think one should benchmark whether, and to what extent, performance is affected by making NaN-safe mode the default. It's weird to report incorrect results with the default setting.
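For reference, a sketch of enabling that setting, assuming a ForwardDiff version where NaN-safe mode is exposed as a Preferences.jl preference (as described on the linked docs page) rather than only as the NANSAFE_MODE_ENABLED source constant:

```julia
# Assumes a ForwardDiff version with Preferences.jl support for this flag;
# Julia must be restarted for the new preference to take effect.
using Preferences, ForwardDiff
set_preferences!(ForwardDiff, "nansafe_mode" => true)
```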

devmotion avatar Jan 19 '22 14:01 devmotion

Great. I see that the implementation is here:

https://github.com/JuliaDiff/ForwardDiff.jl/blob/43ef860cbb20e606a5f835db0b470855c0eb22b7/src/partials.jl#L93-L121

Is it obvious why the test is !isfinite(x) && iszero(partials), and not just iszero(partials)? And if both are needed, why && rather than &? It's strange to put a branch inside an ifelse. Maybe the idea is that, with 12 partials and one x, it's quicker to check x once and branch?

mcabbott avatar Jan 19 '22 14:01 mcabbott