DiffEqBase.jl

ForwardDiff with complex numbers ode sometimes results in NaN derivatives

Open · AmitRotem opened this issue 2 years ago • 2 comments

Found a small issue with ForwardDiff on complex-number ODEs (cf. https://github.com/SciML/DiffEqBase.jl/pull/860): the Dual of time in the ODE function has NaN values. That is, in the ODE function `(u, p, t) -> ...`, `t` is a Dual whose derivative is NaN. Maybe this is a problem with the promotion of `tspan`, affecting the initial `dt` value?

Example:

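```julia
using LinearAlgebra, DifferentialEquations
import ForwardDiff as FD

H0 = randn(2, 2)
u0 = [1.0, 0im]                       # complex initial state with an exact zero entry
Ht(u, p, t) = (H0 * u) * cos(p * t)   # out-of-place RHS: du/dt = H0*u*cos(p*t)
prob0 = ODEProblem(Ht, u0, (0.0, 1.0))

function loss(p)
    prob = remake(prob0; p)
    sol = solve(prob)
    # squared overlap between the initial and final states
    lo = abs2(tr(first(sol.u)' * last(sol.u)))
    lo
end

loss(rand())
FD.derivative(loss, rand())  # returns NaN
```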

The NaNs do not show up in other cases, such as:

- if the initial difference is small, e.g. `Ht(u,p,t) = (H0*u)*sin(p*t)`
- with another random initial state, `u0 = randn(ComplexF64,2)`
- with a real initial state, `u0 = [1.0, 0.0]`
- when some finite `dt` is given, `ODEProblem(...; dt=1)`

Also, NaN-safe mode for ForwardDiff solves this.
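One plausible reading of these observations, offered as an assumption rather than a diagnosis confirmed in this thread: the `0im` entry makes a complex absolute value get evaluated at zero, where its derivative is 0/0. The plain-Julia sketch below (no DiffEq or ForwardDiff involved; the variable names are hypothetical) shows that this local derivative is NaN for a complex zero, that an assumed sign-based rule for a real zero gives 0.0, and that multiplying NaN by a zero perturbation still yields NaN.

```julia
# Hypothetical illustration only; this is plain Julia, not DiffEq or ForwardDiff code.
x, y = 0.0, 0.0           # real and imaginary parts of the `0im` entry of u0
h = hypot(x, y)           # abs(complex(x, y)) == 0.0
dabs_dx = x / h           # d|z|/d re(z) = re(z)/|z|, which is 0/0 = NaN at z = 0
dabs_real = sign(0.0)     # assumed sign-based derivative of real abs at 0: 0.0
println(dabs_dx * 0.0)    # NaN: a zero perturbation does not cancel a NaN derivative
println(dabs_real * 0.0)  # 0.0
```

This would be consistent with a real `u0` or a random complex `u0` (no exact zero entry) avoiding the NaN, but the thread does not establish that the solver actually takes this path.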

AmitRotem · Jan 07 '23 02:01

> Also nan safe mode for ForwardDiff solve this.

Interesting, why does that make a difference?

ChrisRackauckas · Jan 09 '23 14:01

NaN-safe mode (https://juliadiff.org/ForwardDiff.jl/v0.10.2/user/advanced.html#Fixing-NaN/Inf-Issues-1) fixes NaN or Inf derivatives that are multiplied by a zero perturbation, i.e. it makes sure that `NaN * dt` is zero when `dt = 0`. I'm not sure where this situation happens in DiffEq.
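A toy model of what that rule changes, assuming a dual number with a single partial; `Dual1`, `scale`, and `dual_hypot` are hypothetical names for this sketch, not ForwardDiff internals. The chain rule multiplies each input partial by a local derivative. At the origin the local derivatives of `hypot` (i.e. `abs` of a complex number) are 0/0 = NaN, so even zero partials produce NaN unless zero partials are special-cased.

```julia
# Hypothetical single-partial dual number; not ForwardDiff's actual type.
struct Dual1
    v::Float64  # value
    d::Float64  # partial derivative (perturbation)
end

# Chain-rule step: local derivative `dfdx` times input partial `d`.
# With nansafe=true a zero partial yields 0.0 even if `dfdx` is NaN or Inf.
scale(dfdx, d; nansafe=false) = (nansafe && iszero(d)) ? 0.0 : dfdx * d

# abs(x + im*y) computed as hypot(x, y); its local derivatives are x/h and y/h.
function dual_hypot(x::Dual1, y::Dual1; nansafe=false)
    h = hypot(x.v, y.v)
    dh = scale(x.v / h, x.d; nansafe) + scale(y.v / h, y.d; nansafe)
    return Dual1(h, dh)
end

z_re = Dual1(0.0, 0.0)  # zero real part with zero perturbation
z_im = Dual1(0.0, 0.0)  # zero imaginary part with zero perturbation

println(dual_hypot(z_re, z_im).d)                # NaN
println(dual_hypot(z_re, z_im; nansafe=true).d)  # 0.0
```

Away from the origin both modes agree (for example `dual_hypot(Dual1(3.0, 1.0), Dual1(4.0, 0.0)).d` is 0.6), which is why the special case only matters when a zero perturbation meets a NaN or Inf local derivative.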

AmitRotem · Jan 20 '23 23:01