FiniteDifferences.jl
error when function may return both real and complex results
Minimal example:
julia> using FiniteDifferences
julia> using ChainRulesTestUtils: _fdm
julia> j′vp(_fdm, x -> sum(x) > 1 ? zeros(10) : zeros(ComplexF64, 10), rand(10), zeros(10))
ERROR: DimensionMismatch("second dimension of A, 20, does not match length of x, 10")
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:530
[2] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:97 [inlined]
[3] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:275 [inlined]
[4] *(transA::LinearAlgebra.Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:87
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/dev/FiniteDifferences/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/dev/FiniteDifferences/src/grad.jl:73
[7] top-level scope
@ REPL[10]:1
The issue here is that jacobian assumes that to_vec(f(x)) always has the same length when x is perturbed. That assumption fails if f(x) can be real for some inputs and complex for others, because to_vec flattens a complex array into its real and imaginary parts, so a complex output has twice as many entries as a real output of the same size (hence the 20 vs 10 in the error above). This does occur in practice; in my case it was with eigen(M).vectors (ref https://github.com/JuliaDiff/ChainRules.jl/blob/8ce21af2c8f8fa5dad4b1d5aaac32c7847ed1b9f/test/rulesets/LinearAlgebra/factorization.jl#L214). It might be enough to throw a better error message here, suggesting that the user convert the output of f to complex numbers.
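A minimal sketch of both the failure mode and the suggested fix follows. It first checks that flattening a complex vector doubles its length. It then defines a hypothetical wrapper, size_checked (not part of FiniteDifferences), that raises a descriptive ArgumentError whenever an evaluation's flattened output length differs from the length at the base point. Finally, it applies the workaround of always returning complex values together with a complex cotangent ȳ. The names size_checked, f_complex and the choice of central_fdm(5, 1) are illustrative assumptions, not the package's own API or the author's method.

```julia
using FiniteDifferences

# to_vec flattens a complex vector into real and imaginary parts, so the
# flattened length doubles relative to a real vector with the same number of elements.
real_len = length(first(FiniteDifferences.to_vec(zeros(10))))
complex_len = length(first(FiniteDifferences.to_vec(zeros(ComplexF64, 10))))

# Hypothetical helper (not part of FiniteDifferences): wraps `f` so that any
# evaluation whose flattened output length differs from the length at `x0`
# throws a descriptive error instead of a DimensionMismatch deep in LinearAlgebra.
function size_checked(f, x0)
    n0 = length(first(FiniteDifferences.to_vec(f(x0))))
    return function (x)
        y = f(x)
        n = length(first(FiniteDifferences.to_vec(y)))
        n == n0 || throw(ArgumentError(
            "flattened output of f changed from $n0 to $n entries when x was perturbed; " *
            "if f can return either real or complex values, convert its output " *
            "to complex, e.g. complex.(f(x))"))
        return y
    end
end

# The function from the minimal example: real or complex depending on x.
f(x) = sum(x) > 1 ? zeros(10) : zeros(ComplexF64, 10)

# Workaround: make the output type independent of x, and pass a complex
# cotangent so that to_vec(ȳ) has the same length as to_vec(f(x)).
f_complex(x) = complex.(f(x))
x = zeros(10)
ȳ = rand(ComplexF64, 10)
x̄ = j′vp(central_fdm(5, 1), size_checked(f_complex, x), ȳ, x)[1]
```

With the wrapper in place, calling size_checked(f, zeros(10)) on an input where f switches to a real output (e.g. ones(10)) fails with the ArgumentError message rather than the opaque gemv! dimension error.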