ForwardDiff.jl

Derivative is wrong for this inverse quadratic form

Open · colinfang opened this issue 1 year ago · 1 comment

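The derivative with respect to `rho` is correct only if I wrap the matrix in `Symmetric`. (Analytically, the derivative of x' * inv(cov) * x with respect to `rho` at `rho = 0` is -2 * x1 * x2 = -0.04 for x1 = 0.1, x2 = 0.2.)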

using LinearAlgebra
using ForwardDiff

"""
    f_backward(x1, x2, rho)

Evaluate the quadratic form `x' * C⁻¹ * x` for `x = [x1, x2]` and the 2×2 matrix
`C = [1 rho; rho 1]`. It uses a linear solve (`C \\ x`) rather than an explicit inverse.

`C` is a plain `Matrix` and is not wrapped in `Symmetric`. In the output recorded in this
issue, `ForwardDiff.derivative` with respect to `rho` at `rho = 0.0` returned `0.0` for
this variant. The expected value is `-0.04`. The issue discussion reports that this is
fixed by ForwardDiff.jl PR #481.
"""
function f_backward(x1, x2, rho)
    v = [x1, x2]
    C = [1.0 rho
         rho 1.0]
    w = C \ v          # solve C * w = v instead of forming inv(C)
    return v' * w
end

"""
    f_backward_symmetric(x1, x2, rho)

Evaluate the quadratic form `x' * C⁻¹ * x` for `x = [x1, x2]` and `C = [1 rho; rho 1]`,
using a linear solve. Here `C` is wrapped in `Symmetric`, which reads the upper triangle.

In the output recorded in this issue, `ForwardDiff.derivative` with respect to `rho` at
`rho = 0.0` returned the expected `-0.04` for this variant.
"""
function f_backward_symmetric(x1, x2, rho)
    v = [x1, x2]
    C = Symmetric([1.0 rho
                   rho 1.0], :U)   # :U is the default uplo, spelled out for clarity
    w = C \ v
    return v' * w
end

"""
    f_inv(x1, x2, rho)

Evaluate the quadratic form `x' * C⁻¹ * x` for `x = [x1, x2]` and the 2×2 matrix
`C = [1 rho; rho 1]`, forming the explicit inverse with `inv`.

`C` is a plain `Matrix` and is not wrapped in `Symmetric`. In the output recorded in this
issue, `ForwardDiff.derivative` with respect to `rho` at `rho = 0.0` returned `-0.02`,
half the expected `-0.04`. The issue discussion reports that this is fixed by
ForwardDiff.jl PR #481.
"""
function f_inv(x1, x2, rho)
    v = [x1, x2]
    C = [1.0 rho
         rho 1.0]
    Cinv = inv(C)
    return v' * Cinv * v
end

"""
    f_inv_symmetric(x1, x2, rho)

Evaluate the quadratic form `x' * C⁻¹ * x` for `x = [x1, x2]` and `C = [1 rho; rho 1]`,
forming the explicit inverse with `inv`. Here `C` is wrapped in `Symmetric`, which reads
the upper triangle.

In the output recorded in this issue, `ForwardDiff.derivative` with respect to `rho` at
`rho = 0.0` returned the expected `-0.04` for this variant.
"""
function f_inv_symmetric(x1, x2, rho)
    v = [x1, x2]
    C = Symmetric([1.0 rho
                   rho 1.0], :U)   # :U is the default uplo, spelled out for clarity
    Cinv = inv(C)
    return v' * Cinv * v
end

"""
    test(rho)

Print the value of each of the four quadratic-form variants at `x1 = 0.1`, `x2 = 0.2`,
followed by each variant's `ForwardDiff.derivative` with respect to `rho`.

Each line uses `@show`, so the printed output echoes the literal source expression next
to its value. Edit the expressions with care if the output is compared against a
recorded log.
"""
function test(rho)
    # Function values: all four variants compute the same quadratic form.
    @show f_backward(0.1, 0.2, rho)
    @show f_backward_symmetric(0.1, 0.2, rho)
    @show f_inv(0.1, 0.2, rho)
    @show f_inv_symmetric(0.1, 0.2, rho)

    # Derivatives with respect to rho; the closure argument `x` stands in for rho.
    # At rho = 0 the analytic value is -2 * 0.1 * 0.2 = -0.04 for every variant.
    @show ForwardDiff.derivative(x -> f_backward(0.1, 0.2, x), rho)
    @show ForwardDiff.derivative(x -> f_backward_symmetric(0.1, 0.2, x), rho)
    @show ForwardDiff.derivative(x -> f_inv(0.1, 0.2, x), rho)
    @show ForwardDiff.derivative(x -> f_inv_symmetric(0.1, 0.2, x), rho)
end

# Reproduce the report at rho = 0.0.
test(0.0)


f_backward(0.1, 0.2, rho) = 0.05000000000000001
f_backward_symmetric(0.1, 0.2, rho) = 0.05000000000000001
f_inv(0.1, 0.2, rho) = 0.05000000000000001
f_inv_symmetric(0.1, 0.2, rho) = 0.05000000000000001
ForwardDiff.derivative((x->begin
            f_backward(0.1, 0.2, x)
        end), rho) = 0.0
ForwardDiff.derivative((x->begin
            f_backward_symmetric(0.1, 0.2, x)
        end), rho) = -0.04000000000000001
ForwardDiff.derivative((x->begin
            f_inv(0.1, 0.2, x)
        end), rho) = -0.020000000000000004
ForwardDiff.derivative((x->begin
            f_inv_symmetric(0.1, 0.2, x)
        end), rho) = -0.04000000000000001

— colinfang, Mar 08 '23 21:03

Is the desired answer -0.04 for all four derivatives?

If so, this is fixed by https://github.com/JuliaDiff/ForwardDiff.jl/pull/481, which is available on master.

— mcabbott, Mar 09 '23 00:03