SciMLSensitivity.jl

JacVecOperator differentiation issues

Open ArnoStrouwen opened this issue 3 years ago • 2 comments

Issues with JacVecOperator and LinSolveGMRES() when performing sensitivity analysis. Also a more general ReverseDiff issue?

using OrdinaryDiffEq, DiffEqSensitivity, DiffEqOperators, Zygote, ReverseDiff, ForwardDiff, ModelingToolkit
function fiip(du,u,p,t)
    du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
function sum_of_solution(x)
    _prob = remake(prob,u0=x[1:2],p=x[3:end])
    sum(solve(_prob,solver,saveat=0.1))
end

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ReverseDiff.gradient(sum_of_solution,[u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false,autojacvec=false))
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

Jv = JacVecOperator(fiip,u0,p,0.0);
fiip3 = ODEFunction(fiip;jac_prototype=Jv);
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

using Sundials
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false))
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

(tester) pkg> status
      Status `~/Koofr/tester/Project.toml`
  [9fdde737] DiffEqOperators v4.30.0
  [41bf760c] DiffEqSensitivity v6.55.3
  [f6369f11] ForwardDiff v0.10.18
  [961ee093] ModelingToolkit v5.25.1
  [1dea7af3] OrdinaryDiffEq v5.60.1
  [37e2e3b7] ReverseDiff v1.9.0
  [c3572dad] Sundials v4.5.3
  [e88e6eb3] Zygote v0.6.17

ArnoStrouwen, Jul 22 '21 16:07

The vast majority of this issue was actually addressed by https://github.com/SciML/DiffEqBase.jl/pull/685, i.e. the solve was using the automatic sensealg instead of picking up the one set on prob. With that handled, most of the cases are either fixed or fail for upstream reasons:

using OrdinaryDiffEq, DiffEqSensitivity, DiffEqOperators, Zygote, ReverseDiff, ForwardDiff, ModelingToolkit
function fiip(du,u,p,t)
    du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
function sum_of_solution(x)
    _prob = remake(prob,u0=x[1:2],p=x[3:end])
    sum(solve(_prob,solver,saveat=0.1))
end

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ReverseDiff.gradient(sum_of_solution,[u0;p]) # Expected failure
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p]) # LinearAlgebra.givensAlgorithm not differentiable
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false,autojacvec=false))
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p]) # LinearAlgebra.givensAlgorithm not differentiable
Zygote.gradient(sum_of_solution,[u0;p])

Jv = JacVecOperator(fiip,u0,p,0.0);
fiip3 = ODEFunction(fiip;jac_prototype=Jv);
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

using Sundials
prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p]) # Expected failure
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint(autodiff=false))
solver = CVODE_BDF(linear_solver=:GMRES)
sum_of_solution([u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

So the Sundials cases all work (except ForwardDiff through Sundials, but that failure is expected). Zygote gradients mostly work. ForwardDiff with GMRES fails because LinearAlgebra.givensAlgorithm is non-generic (it requires AbstractFloat or Complex arguments), so that's a Julia Base issue, and one worth opening an issue about.

So that narrows down the issue to:

using OrdinaryDiffEq, DiffEqSensitivity, DiffEqOperators, Zygote, ReverseDiff, ForwardDiff, ModelingToolkit
function fiip(du,u,p,t)
    du[1] = dx = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = dy = -p[3]*u[2] + p[4]*u[1]*u[2]
end
p = [1.5,1.0,3.0,1.0]; u0 = [1.0;1.0]
function sum_of_solution(x)
    _prob = remake(prob,u0=x[1:2],p=x[3:end])
    sum(solve(_prob,solver,saveat=0.1))
end

Jv = JacVecOperator(fiip,u0,p,0.0);
fiip3 = ODEFunction(fiip;jac_prototype=Jv);
prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF()
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

prob = ODEProblem(fiip3,u0,(0.0,10.0),p,sensealg=InterpolatingAdjoint())
solver = QNDF(linsolve=LinSolveGMRES())
sum_of_solution([u0;p])
ForwardDiff.gradient(sum_of_solution,[u0;p])
Zygote.gradient(sum_of_solution,[u0;p])

These remaining failures are all JacVecOperator regressions.

ChrisRackauckas, Jul 26 '21 15:07

The regression comes from the "corrected" handling of jac_prototype. Hmm, I think the best thing is just to fix it so that the user does not have to pass the jac_prototype at all and the solver automatically knows how to use the lazy form. @YingboMa let's work on this soon, because with QNDF and FBDF doing so well, fully removing CVODE_BDF(linear_solver=:GMRES) is very close now.

ChrisRackauckas, Jul 26 '21 15:07
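
For readers following the thread, the "lazy form" refers to applying the Jacobian to a vector without ever building the Jacobian matrix, which is all a Krylov method such as GMRES needs. The snippet below is only a minimal sketch of that idea. It assumes a central finite difference in place of whatever differentiation the real operator uses, and `jacvec_fd` is a hypothetical helper written for this illustration, not a function from DiffEqOperators or any other package.

```julia
# Lotka-Volterra right-hand side from the thread, in-place form.
function fiip(du, u, p, t)
    du[1] = p[1]*u[1] - p[2]*u[1]*u[2]
    du[2] = -p[3]*u[2] + p[4]*u[1]*u[2]
    return nothing
end

# Hypothetical helper (illustration only): approximate J(u)*v with a central
# difference along the direction v, so the matrix J is never formed.
function jacvec_fd(f!, u, p, t, v; h = 1e-6)
    du_plus = similar(u)
    du_minus = similar(u)
    f!(du_plus, u .+ h .* v, p, t)    # f(u + h*v)
    f!(du_minus, u .- h .* v, p, t)   # f(u - h*v)
    return (du_plus .- du_minus) ./ (2h)
end

p = [1.5, 1.0, 3.0, 1.0]
u0 = [1.0, 1.0]
v = [1.0, 2.0]

# Matrix-free product.
Jv_lazy = jacvec_fd(fiip, u0, p, 0.0, v)

# Dense analytic Jacobian of fiip at u0, used only to check the result.
J = [(p[1] - p[2]*u0[2]) (-p[2]*u0[1]);
     (p[4]*u0[2])        (-p[3] + p[4]*u0[1])]
Jv_dense = J * v

println(Jv_lazy)
println(Jv_dense)
```

The two products should agree up to floating-point roundoff, because the right-hand side is quadratic in u and a central difference carries no truncation error in that case. A GMRES-based linear solve only ever calls a product like `jacvec_fd`, which is why a solver that picked the lazy form automatically would not need the user to pass a jac_prototype.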