SciMLSensitivity.jl
SciMLSensitivity.jl copied to clipboard
JacVecOperator differentiation issues
Issues with JacVecOperator and LinSolveGMRES() when performing sensitivity analysis. Also a more general ReverseDiff issue?
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqOperators, Zygote, ReverseDiff, ForwardDiff, ModelingToolkit
"""
    fiip(du, u, p, t)

In-place right-hand side of a two-state system in Lotka–Volterra form.

Writes the derivative of the state `u` into `du`:

- `du[1] = p[1]*u[1] - p[2]*u[1]*u[2]`
- `du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]`

`p` is indexed at positions 1 through 4. `t` is accepted but not used.
"""
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    # Last expression on purpose: the function still evaluates to the value
    # stored in du[2], exactly as before the unused locals were removed.
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
# Four coefficients, read by `fiip` as p[1] through p[4].
p = [1.5, 1.0, 3.0, 1.0]
# Two-component initial state.
u0 = [1.0, 1.0]
"""
    sum_of_solution(x)

Scalar objective used for the gradient calls in this script.

Splits `x` into an initial state (`x[1:2]`) and parameters (`x[3:end]`),
remakes the global `prob` with them, solves it with the global `solver`
using `saveat=0.1`, and returns `sum` of the resulting solution.
"""
function sum_of_solution(x)
    initial_state = x[1:2]
    parameters = x[3:end]
    solution = solve(remake(prob; u0=initial_state, p=parameters), solver; saveat=0.1)
    return sum(solution)
end
# `sum_of_solution` reads the globals `prob` and `solver`, so every case below
# rebinds them before evaluating the objective and its gradients at [u0;p].
# Case 1: QNDF with no linsolve argument, InterpolatingAdjoint with default options.
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ReverseDiff.gradient(sum_of_solution,[u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 2: same problem, QNDF with linsolve=LinSolveGMRES().
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 3: as case 2, but the adjoint is built with autodiff=false, autojacvec=false.
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false,autojacvec=false))
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 4: wrap `fiip` in an ODEFunction whose jac_prototype is a JacVecOperator
# constructed from (fiip, u0, p, 0.0); QNDF with no linsolve argument.
Jv = JacVecOperator(fiip,u0,p,0.0);
fiip3 = ODEFunction(fiip;jac_prototype=Jv);
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 5: the JacVecOperator-backed function again, QNDF with linsolve=LinSolveGMRES().
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 6: plain `fiip` with Sundials' CVODE_BDF and linear_solver=:GMRES.
using Sundials
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 7: as case 6, but the adjoint is built with autodiff=false; Zygote only.
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false))
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
(tester) pkg> status
Status `~/Koofr/tester/Project.toml`
[9fdde737] DiffEqOperators v4.30.0
[41bf760c] DiffEqSensitivity v6.55.3
[f6369f11] ForwardDiff v0.10.18
[961ee093] ModelingToolkit v5.25.1
[1dea7af3] OrdinaryDiffEq v5.60.1
[37e2e3b7] ReverseDiff v1.9.0
[c3572dad] Sundials v4.5.3
[e88e6eb3] Zygote v0.6.17
The vast majority of this issue was actually https://github.com/SciML/DiffEqBase.jl/pull/685, i.e. it was using the automatic sensealg instead of realizing it should be grabbed from prob. With that handled, most issues are either fixed or upstream:
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqOperators, Zygote, ReverseDiff, ForwardDiff, ModelingToolkit
"""
    fiip(du, u, p, t)

In-place right-hand side of a two-state system in Lotka–Volterra form.

Writes the derivative of the state `u` into `du`:

- `du[1] = p[1]*u[1] - p[2]*u[1]*u[2]`
- `du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]`

`p` is indexed at positions 1 through 4. `t` is accepted but not used.
"""
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    # Last expression on purpose: the function still evaluates to the value
    # stored in du[2], exactly as before the unused locals were removed.
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
# Four coefficients, read by `fiip` as p[1] through p[4].
p = [1.5, 1.0, 3.0, 1.0]
# Two-component initial state.
u0 = [1.0, 1.0]
"""
    sum_of_solution(x)

Scalar objective used for the gradient calls in this script.

Splits `x` into an initial state (`x[1:2]`) and parameters (`x[3:end]`),
remakes the global `prob` with them, solves it with the global `solver`
using `saveat=0.1`, and returns `sum` of the resulting solution.
"""
function sum_of_solution(x)
    initial_state = x[1:2]
    parameters = x[3:end]
    solution = solve(remake(prob; u0=initial_state, p=parameters), solver; saveat=0.1)
    return sum(solution)
end
# `sum_of_solution` reads the globals `prob` and `solver`, so every case below
# rebinds them before evaluating the objective and its gradients at [u0;p].
# Case 1: QNDF with no linsolve argument, InterpolatingAdjoint with default options.
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ReverseDiff.gradient(sum_of_solution,[u0;p]) # Expected failure
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 2: same problem, QNDF with linsolve=LinSolveGMRES().
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p]) # LinearAlgebra.givensAlgorithm not differentiable
Zygote.gradient(sum_of_solution,[u0;p])
# Case 3: as case 2, but the adjoint is built with autodiff=false, autojacvec=false.
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false,autojacvec=false))
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p]) # LinearAlgebra.givensAlgorithm not differentiable
Zygote.gradient(sum_of_solution,[u0;p])
# Case 4: wrap `fiip` in an ODEFunction whose jac_prototype is a JacVecOperator
# constructed from (fiip, u0, p, 0.0); QNDF with no linsolve argument.
Jv = JacVecOperator(fiip,u0,p,0.0);
fiip3 = ODEFunction(fiip;jac_prototype=Jv);
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 5: the JacVecOperator-backed function again, QNDF with linsolve=LinSolveGMRES().
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case 6: plain `fiip` with Sundials' CVODE_BDF and linear_solver=:GMRES.
using Sundials
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p]) # Expected failure
Zygote.gradient(sum_of_solution,[u0;p])
# Case 7: as case 6, but the adjoint is built with autodiff=false; Zygote only.
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false))
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
So the Sundials cases all work (except ForwardDiff on Sundials, but that's expected). Zygote gradients mostly work. ForwardDiff with GMRES fails because LinearAlgebra.givensAlgorithm is non-generic (it requires AbstractFloat or Complex), so that's a Julia Base issue (one that would be worth opening an issue about).
So that narrows down the issue to:
using OrdinaryDiffEq, DiffEqSensitivity, DiffEqOperators, Zygote, ReverseDiff, ForwardDiff, ModelingToolkit
"""
    fiip(du, u, p, t)

In-place right-hand side of a two-state system in Lotka–Volterra form.

Writes the derivative of the state `u` into `du`:

- `du[1] = p[1]*u[1] - p[2]*u[1]*u[2]`
- `du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]`

`p` is indexed at positions 1 through 4. `t` is accepted but not used.
"""
function fiip(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    # Last expression on purpose: the function still evaluates to the value
    # stored in du[2], exactly as before the unused locals were removed.
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
end
# Four coefficients, read by `fiip` as p[1] through p[4].
p = [1.5, 1.0, 3.0, 1.0]
# Two-component initial state.
u0 = [1.0, 1.0]
"""
    sum_of_solution(x)

Scalar objective used for the gradient calls in this script.

Splits `x` into an initial state (`x[1:2]`) and parameters (`x[3:end]`),
remakes the global `prob` with them, solves it with the global `solver`
using `saveat=0.1`, and returns `sum` of the resulting solution.
"""
function sum_of_solution(x)
    initial_state = x[1:2]
    parameters = x[3:end]
    solution = solve(remake(prob; u0=initial_state, p=parameters), solver; saveat=0.1)
    return sum(solution)
end
# Reduced reproduction: only the cases that pass a JacVecOperator as jac_prototype.
# `sum_of_solution` reads the globals `prob` and `solver`, so each case rebinds them.
# Case A: JacVecOperator-backed ODEFunction, QNDF with no linsolve argument.
Jv = JacVecOperator(fiip,u0,p,0.0);
fiip3 = ODEFunction(fiip;jac_prototype=Jv);
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
# Case B: same function, QNDF with linsolve=LinSolveGMRES().
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])
All JacVecOperator regressions.
It comes from the "corrected" handling of jac_prototype. Hmm, I think the best thing is just to fix it so that the user does not have to pass the jac_prototype and it automatically knows how to do the lazy form. @YingboMa let's work on this soon because, with QNDF and FBDF doing so well, fully removing CVODE_BDF(linear_solver=:GMRES) is very close now.