SciMLSensitivity.jl
DiscreteProblem adjoints with scalar states
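```julia
using DiffEqFlux, OrdinaryDiffEq, Flux  # Flux provides ADAM

function loss1(p)
    f(x, p, t) = 1  # scalar state; the map does not use p
    prob = DiscreteProblem(f, 0, (1, 10), p)
    sol = solve(prob, FunctionMap(scale_by_time = true), saveat = [1, 2, 3])
    return sum(sol)
end

# Fails while setting up the adjoint, with the error below
DiffEqFlux.sciml_train(loss1, [1], ADAM(0.05), maxiters = 20)
```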
Error at https://github.com/SciML/DiffEqSensitivity.jl/blob/c31e529853c2c45cab583e03b08c24af99c50c2a/src/adjoint_common.jl#L41:
type DiscreteFunction has no field mass_matrix
We don't have an overload for DiscreteProblem yet. @ChrisRackauckas I think for DiscreteProblems we need AD pass-through.
Yes, we should probably just make DiscreteProblem default to SensitivityADPassThrough() and it should be fine.
Zygote won't work, so I set up the system to use ReverseDiff as the default. https://github.com/SciML/DiffEqSensitivity.jl/pull/371 fixes most of the cases.
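```julia
# SciMLSensitivity was called DiffEqSensitivity when this was written
using OrdinaryDiffEq, SciMLSensitivity, Zygote, Test

function loss1(p; sensealg = nothing)
    f(x, p, t) = [p[1]]  # length-1 vector state instead of a scalar
    prob = DiscreteProblem(f, [0.0], (1, 10), p)
    # forward sensealg to solve so each gradient below actually uses it
    sol = solve(prob, FunctionMap(scale_by_time = true), saveat = [1, 2, 3],
                sensealg = sensealg)
    return sum(sol)
end

dp1 = Zygote.gradient(loss1, [1.0])[1]  # default sensealg
dp2 = Zygote.gradient(x -> loss1(x, sensealg = TrackerAdjoint()), [1.0])[1]
dp3 = Zygote.gradient(x -> loss1(x, sensealg = ReverseDiffAdjoint()), [1.0])[1]
dp4 = Zygote.gradient(x -> loss1(x, sensealg = ForwardDiffSensitivity()), [1.0])[1]
@test dp1 ≈ dp2
@test dp1 ≈ dp3
@test dp1 ≈ dp4
```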
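For a sanity check on what all four gradients should agree on, here is the hand-derived value. It assumes `FunctionMap(scale_by_time = true)` applies the update below, that a `DiscreteProblem` defaults to a step of 1, and that each `saveat` time is stored exactly once; under those assumptions every `sensealg` should return `[3.0]`.

```math
u_{n+1} = u_n + \Delta t\, f(u_n, p, t_n), \quad f = p_1,\ \Delta t = 1,\ u(1) = 0
\;\Rightarrow\; u(2) = p_1,\ u(3) = 2p_1,\quad
L(p) = u(1) + u(2) + u(3) = 3p_1,\quad \frac{\partial L}{\partial p_1} = 3.
```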
But the scalar-state case from the original report is still going to give any adjoint an issue. Fortunately, you should never be applying reverse mode to a scalar-valued ODE anyway, so we're good.
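For a scalar state with a single parameter, forward mode is the natural choice. Below is a minimal sketch that differentiates straight through the solver with ForwardDiff instead of using an adjoint. It assumes ForwardDiff dual numbers propagate through a scalar `DiscreteProblem` (this is not tested in the thread), and `loss_scalar` is a hypothetical name used only for illustration.

```julia
using OrdinaryDiffEq, ForwardDiff

# Hypothetical scalar-state loss: f returns the scalar parameter itself.
function loss_scalar(p)
    f(x, p, t) = p
    # zero(p) keeps u0 the same number type as p (Float64 or a Dual)
    prob = DiscreteProblem(f, zero(p), (1, 10), p)
    sol = solve(prob, FunctionMap(scale_by_time = true), saveat = [1, 2, 3])
    return sum(sol.u)  # sol.u holds the saved scalar states
end

# Forward-mode derivative with respect to the scalar parameter
dp = ForwardDiff.derivative(loss_scalar, 1.0)
```

Since the loss is linear in `p`, a finite difference such as `(loss_scalar(1.5) - loss_scalar(1.0)) / 0.5` should match `dp`, which is a quick way to check the sketch.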