DiffEqBase.jl
ForwardDiff with complex-valued ODEs sometimes results in NaN derivatives
Found a small issue with ForwardDiff on complex-valued ODEs: https://github.com/SciML/DiffEqBase.jl/pull/860
The time argument passed to the ODE function is a Dual with NaN values. That is, inside the ODE function (u,p,t) -> ..., t is a Dual whose derivative is NaN.
Maybe this is a problem with the promotion of tspan, which then affects the initial dt value?
Example:
using LinearAlgebra, DifferentialEquations
import ForwardDiff as FD
H0 = randn(2, 2)
u0 = [1.0, 0im]                        # complex initial state
Ht(u, p, t) = (H0 * u) * cos(p * t)    # out-of-place RHS; p enters through cos(p*t)
prob0 = ODEProblem(Ht, u0, (0.0, 1.0))
function loss(p)
    prob = remake(prob0; p)
    sol = solve(prob)
    # squared overlap between the initial and final states
    lo = abs2(tr(first(sol.u)' * last(sol.u)))
    lo
end
loss(rand())                   # works
FD.derivative(loss, rand())    # returns NaN
This does not show up in other cases, such as:
- when the initial change is small, e.g. Ht(u,p,t) = (H0*u)*sin(p*t), where the right-hand side vanishes at t = 0
- with another random initial state, u0 = randn(ComplexF64, 2)
- with a real initial state, u0 = [1.0, 0.0]
- when a finite initial dt is given, e.g. ODEProblem(...; dt=1)
Also, NaN-safe mode for ForwardDiff solves this.
Interesting, why does that make a difference?
https://juliadiff.org/ForwardDiff.jl/v0.10.2/user/advanced.html#Fixing-NaN/Inf-Issues-1
It fixes NaN or Inf derivatives that arise with a zero perturbation, meaning it makes sure that NaN*dt is zero when dt = 0.
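To illustrate the rule, here is a minimal sketch using a hypothetical `MiniDual` type (not ForwardDiff's `Dual`). It assumes NaN-safe mode works as the linked docs describe it: a partial that is exactly zero, multiplied by a non-finite value, gives zero instead of the IEEE result NaN. The names `MiniDual`, `scale` and `mul` are illustrative only.

```julia
# Hypothetical minimal dual number, for illustration only (not ForwardDiff.Dual):
# a value plus a single partial derivative with respect to the parameter.
struct MiniDual
    val::Float64
    der::Float64
end

# Multiply a partial by a value. With nansafe = true, a zero partial stays zero
# even when the value is NaN or Inf (plain IEEE arithmetic gives 0.0 * Inf == NaN).
function scale(der::Float64, val::Float64; nansafe::Bool = false)
    if nansafe && iszero(der) && !isfinite(val)
        return zero(der)
    end
    return der * val
end

# Product rule: (a * b)' = a' * b + b' * a
function mul(a::MiniDual, b::MiniDual; nansafe::Bool = false)
    d = scale(a.der, b.val; nansafe = nansafe) + scale(b.der, a.val; nansafe = nansafe)
    return MiniDual(a.val * b.val, d)
end

# Both factors are constant w.r.t. the parameter (zero perturbation),
# so the true derivative of their product is zero.
a = MiniDual(2.0, 0.0)
b = MiniDual(Inf, 0.0)

mul(a, b)                  # MiniDual(Inf, NaN): 0.0 * Inf poisons the partial
mul(a, b; nansafe = true)  # MiniDual(Inf, 0.0)
```

In this sketch a partial that is already NaN is not repaired: `scale(NaN, 0.0; nansafe = true)` still returns NaN, because only exactly-zero partials are special-cased.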
I'm not sure where in DiffEq this situation arises.