SciMLSensitivity.jl

DiscreteProblem adjoints with scalar states

Open leventov opened this issue 3 years ago • 3 comments

function loss1(p)
  f(x,p,t) = 1
  prob = DiscreteProblem(f, 0, (1,10), p)
  sol = solve(prob, FunctionMap(scale_by_time = true), saveat=[1,2,3])
  return sum(sol)
end
DiffEqFlux.sciml_train(loss1,[1],ADAM(0.05),maxiters = 20)

Error in line https://github.com/SciML/DiffEqSensitivity.jl/blob/c31e529853c2c45cab583e03b08c24af99c50c2a/src/adjoint_common.jl#L41:

type DiscreteFunction has no field mass_matrix

leventov, Jan 01 '21 11:01

We don't have an overload for DiscreteProblem yet. @ChrisRackauckas I think for DiscreteProblems we need AD pass through.

YingboMa, Jan 01 '21 16:01

Yes, we should probably just make DiscreteProblem default to SensitivityADPassThrough() and it should be fine.

ChrisRackauckas, Jan 01 '21 16:01

Zygote won't work, so I set up the system to have ReverseDiff as the default. https://github.com/SciML/DiffEqSensitivity.jl/pull/371 fixes most of the cases.

using OrdinaryDiffEq, DiffEqSensitivity, Zygote, Test

function loss1(p;sensealg=nothing)
  f(x,p,t) = [p[1]]
  prob = DiscreteProblem(f, [0.0], (1,10), p)
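  # Forward sensealg to solve; otherwise every gradient below takes the same default path
  sol = solve(prob, FunctionMap(scale_by_time = true), saveat=[1,2,3], sensealg=sensealg)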
  return sum(sol)
end
dp1 = Zygote.gradient(loss1,[1.0])[1]
dp2 = Zygote.gradient(x->loss1(x,sensealg=TrackerAdjoint()),[1.0])[1]
dp3 = Zygote.gradient(x->loss1(x,sensealg=ReverseDiffAdjoint()),[1.0])[1]
dp4 = Zygote.gradient(x->loss1(x,sensealg=ForwardDiffSensitivity()),[1.0])[1]
@test dp1 == dp2
@test dp1 == dp3
@test dp1 == dp4

But the original case with a scalar state is still going to give any adjoint an issue. Fortunately, you should never be applying reverse mode on a scalar-valued ODE, so we're good.

ChrisRackauckas, Jan 02 '21 03:01
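
For the scalar-state case that the thread leaves unsupported, a forward-mode derivative is one possible route. The sketch below is not the library's code path: `scalar_loss` is a hypothetical hand-rolled stand-in for the map in the test above, assuming `FunctionMap(scale_by_time = true)` advances the state as u + dt*f(u,p,t) with dt = 1, and that the values saved at t = 1, 2, 3 are summed. It only depends on ForwardDiff.

```julia
using ForwardDiff

# Hypothetical stand-in for the discrete map with a scalar state.
# Assumed update rule: u_next = u + dt * f(u, p, t), with f = p and dt = 1.
function scalar_loss(p)
    u = zero(p)       # scalar initial state, typed like p so dual numbers propagate
    total = zero(p)
    for t in 1:3      # accumulate the state at t = 1, 2, 3
        total += u
        u += p        # one unit step of the assumed map
    end
    return total      # 0 + p + 2p = 3p
end

# Forward-mode derivative with respect to the scalar parameter
dp = ForwardDiff.derivative(scalar_loss, 1.0)  # 3.0
```

Under those assumptions the result, 3.0, is what `dp1` through `dp4` above would be expected to hold as a one-element array, which gives a quick sanity check on the vector-state version.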