MethodOfLines.jl
Zygote fails to compute gradients w.r.t. parameters (when using `remake`)
MWE:
using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets
using Zygote
import AbstractDifferentiation as AD
# Method of Manufactured Solutions: exact solution
# NOTE(review): exp(-t)*cos(x) satisfies the PDE below only when α + β == 1, and
# `u_exact` is not referenced anywhere else in this MWE.
u_exact = (x,t) -> exp.(-t) * cos.(x)
# Parameters, variables, and derivatives
@parameters t x
@variables u(..)
# α and β are the parameters the gradient is taken with respect to.
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2
# 1D PDE and boundary conditions
# Diffusion equation u_t = (α + β) * u_xx: the two parameters only enter through their sum.
eq = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
# Initial condition at t = 0, then Dirichlet conditions at x = 0 and x = 1.
bcs = [u(0, x) ~ cos(x),
u(t, 0) ~ exp(-t),
u(t, 1) ~ exp(-t) * cos(1)]
# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
x ∈ Interval(0.0, 1.0)]
# Parameters (default values passed to the PDESystem below)
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)
# Method of lines discretization: x is discretized with spacing dx, t is the time variable.
dx = 0.1
order = 2  # NOTE(review): not passed to MOLFiniteDifference, so it has no effect here
discretization = MOLFiniteDifference([x => dx], t)
# Convert the PDE problem into an ODE problem
prob = discretize(pdesys,discretization)
# Scalar objective for the gradient test.
# `ps[1]` and `ps[2]` become the values of α and β. The discretized problem is rebuilt
# through `remake` with that symbolic parameter map and solved, and the result sums
# `sol[u(t, x)][end, :]`. That is presumably the profile over x at the last saved time;
# it assumes time is the first axis (TODO confirm).
function pde_solution(ps)
    parammap = [α => ps[1], β => ps[2]]
    newprob = remake(prob, p = parammap)
    sol = solve(newprob, Tsit5(), saveat = 0.1)
    return sum(sol[u(t, x)][end, :])
end
# Differentiate `pde_solution` with Zygote through the AbstractDifferentiation interface,
# at a random 2-element vector interpreted as [α, β].
ADZyg = AD.ZygoteBackend()
grad = AD.gradient(ADZyg, pde_solution, rand(2))
I used AbstractDifferentiation above, but calling Zygote directly gives the same error:
ERROR: MethodError: no method matching size(::IRTools.Inner.Undefined)
Closest candidates are:
size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted})
@ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:581
size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer)
@ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:580
size(::Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRPackedQ})
@ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:584
...
Stacktrace:
[1] axes(A::IRTools.Inner.Undefined)
@ Base ./abstractarray.jl:98
[2] _tryaxes(x::IRTools.Inner.Undefined)
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/array.jl:188
[3] map
@ ./tuple.jl:274 [inlined]
[4] adjoint
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/array.jl:322 [inlined]
[5] _pullback
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
[6] _pullback
@ ./iterators.jl:370 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(zip), ::IRTools.Inner.Undefined, ::Vector{Float64})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[8] _pullback
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/ModelingToolkit/dkLCE/src/utils.jl:635 [inlined]
[9] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::IRTools.Inner.Undefined)
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[10] _pullback
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/ModelingToolkit/dkLCE/src/variables.jl:149 [inlined]
[11] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.process_p_u0_symbolic), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, ::Vector{Pair{Num, Float64}}, ::Vector{Float64})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[12] _pullback
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/remake.jl:78 [inlined]
[13] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#617", ::Missing, ::Missing, ::Missing, ::Vector{Pair{Num, Float64}}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[14] _pullback
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/remake.jl:52 [inlined]
[15] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:p,), Tuple{Vector{Pair{Num, Float64}}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[16] _pullback
@ ~/GitProjects/hyco/scripts/mwe.jl:40 [inlined]
[17] _pullback(ctx::Zygote.Context{false}, f::typeof(pde_solution), args::Vector{Float64})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[18] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
[19] pullback
@ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42 [inlined]
[20] gradient(f::Function, args::Vector{Float64})
@ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
[21] top-level scope
@ REPL[7]:1
To work around this, I modified `remake` in SciMLBase, essentially removing the symbolic checks here: https://github.com/SciML/SciMLBase.jl/blob/75605b1a8754bc4452761100589a6020b8e9f035/src/remake.jl#L71-L79
Once that is done, I get a new error:
ERROR: No matching function wrapper was found!
The stack trace starts at:
[7] (::ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem})(::Vector{Float64}, ::Vararg{Any})
@ SciMLBase /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/scimlfunctions.jl:2267
FunctionWrappersWrappers's first non-_call stacktrace:
[6] (::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false})(::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64}, ::Float64)
@ FunctionWrappersWrappers /Net/Groups/BGI/people/mchettouh/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10
I haven't included the entire stack trace because it is very long.
This is linked to this issue: https://github.com/SciML/SciMLSensitivity.jl/issues/794
Could a workaround be to avoid wrapping altogether by passing the `wrap = Val(false)` keyword argument to `solve`?
https://docs.sciml.ai/MethodOfLines/stable/solutions/#Original-solution
@BernhardAhrens It currently crashes in `remake`; it never reaches `solve`.
I've tried to pinpoint the error more precisely and arrived at this new MWE:
using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets
using PDEBase: add_metadata!
using ModelingToolkit: get_metadata
using Zygote
import AbstractDifferentiation as AD
# Method of Manufactured Solutions: exact solution
# NOTE(review): `u_exact` is not referenced anywhere else in this MWE.
u_exact = (x,t) -> exp.(-t) * cos.(x)
# Parameters, variables, and derivatives
@parameters x t
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2
# 1D PDE and boundary conditions
eq = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
# Initial condition at t = 0, then Dirichlet conditions at x = 0 and x = 1.
bcs = [u(0, x) ~ cos(x),
u(t, 0) ~ exp(-t),
u(t, 1) ~ exp(-t) * cos(1)]
# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
x ∈ Interval(0.0, 1.0)]
# Parameters
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)
# Method of lines discretization
dx = 0.1
order = 2  # NOTE(review): not passed to MOLFiniteDifference, so it has no effect here
discretization = MOLFiniteDifference([x => dx], t)
# Convert the PDE problem into an ODE problem.
# This is done step by step instead of via `discretize`, so the ODEProblem can be
# forced to use `SciMLBase.FullSpecialize`.
sys,tspan = symbolic_discretize(pdesys,discretization)
simpsys = structural_simplify(sys)
# PDEBase helper: presumably attaches `sys` to the simplified system's metadata - verify.
add_metadata!(get_metadata(simpsys), sys)
# `Pair[]`: empty initial-value map (presumably the system's defaults are used).
prob = ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan; discretization.kwargs...)
"""
    remake_p(prob::ODEProblem, p; simpsys=simpsys)

Simplified replacement for `remake(prob, p = p)`: build a fresh
`ODEProblem{true, SciMLBase.FullSpecialize}` from `simpsys`. The new problem reuses
`prob.u0` and `prob.tspan` and takes its parameters from `p`, a parameter map such as
`[α => 1.2, β => 0.3]`.
"""
function remake_p(prob::ODEProblem, p; simpsys=simpsys)
    # Only u0 and tspan are carried over; the right-hand side is regenerated from `simpsys`.
    return ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, prob.u0, prob.tspan, p)
end
# Scalar objective for the gradient test.
# The raw 2-vector is mapped onto α/β, the problem is rebuilt via `remake_p` and solved,
# and `sol[u(t, x)][end, :]` is summed. That is presumably the x-profile at the last
# saved time (TODO confirm axis order).
function pde_solution2(ps)
    parammap = [α => ps[1], β => ps[2]]
    newprob = remake_p(prob, parammap)
    solution = solve(newprob, Tsit5(), saveat = 0.1)
    final_profile = solution[u(t, x)][end, :]
    return sum(final_profile)
end
# Plain forward evaluation (no AD) as a sanity check; the result is discarded.
pde_solution2([1.2,.3]);
ADzyg = AD.ZygoteBackend()
# Zygote gradient w.r.t. [α, β] at a random point.
AD.gradient(ADzyg, pde_solution2, rand(2))
This version avoids some of the Zygote errors (it forces `FullSpecialize` and uses a simplified `remake`), but it still crashes with:
ERROR: Compiling Tuple{Type{Dict}, Vector{Pair{Num, Float64}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] macro expansion
@ home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101 [inlined]
[2] _pullback(ctx::Zygote.Context{false}, f::Type{Dict}, args::Vector{Pair{Num, Float64}})
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101
[3] _pullback
@ home/.julia/packages/ModelingToolkit/dkLCE/src/utils.jl:633 [inlined]
[4] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Pair{Num, Float64}}, ::IRTools.Inner.Undefined)
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[5] _pullback
@ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:694 [inlined]
[6] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##get_u0_p#578", ::Bool, ::Bool, ::Bool, ::typeof(ModelingToolkit.get_u0_p), ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[7] _pullback
@ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:684 [inlined]
[8] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:tofloat, :use_union, :symbolic_u0), Tuple{Bool, Bool, Bool}}, ::typeof(ModelingToolkit.get_u0_p), ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[9] _pullback
@ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:724 [inlined]
[10] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##process_DEProblem#579", ::Bool, ::Nothing, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Symbolics.SerialForm, ::Bool, ::Bool, ::Bool, ::Bool, ::Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:t, :has_difference, :check_length), Tuple{Float64, Bool, Bool}}}, ::typeof(ModelingToolkit.process_DEProblem), ::Type{ODEFunction{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[11] _pullback
@ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:707 [inlined]
[12] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:t, :has_difference, :check_length), Tuple{Float64, Bool, Bool}}, ::typeof(ModelingToolkit.process_DEProblem), ::Type{ODEFunction{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[13] _pullback
@ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:834 [inlined]
[14] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##_#586", ::Nothing, ::Bool, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::Type{ODEProblem{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Tuple{Float64, Float64}, ::Vector{Pair{Num, Float64}})
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
[15] _pullback
@ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:827 [inlined]
[16] _pullback(::Zygote.Context{false}, ::Type{ODEProblem{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Tuple{Float64, Float64}, ::Vector{Pair{Num, Float64}})
@ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
....
The crash happens in Julia Base, because Zygote cannot differentiate through the `try`/`catch` inside the `Dict(kv)` constructor: https://github.com/JuliaLang/julia/blob/b74daf501619ac4be061c67d80608c4c8822fc36/base/dict.jl#L114-L126
# Quoted from Julia Base (base/dict.jl): build a Dict from an iterable `kv`,
# with key/value types derived from `eltype(kv)`.
function Dict(kv)
try
dict_with_eltype((K, V) -> Dict{K, V}, kv, eltype(kv))
catch
# On failure: if `kv` is not an iterable of tuples/pairs, replace the error with a
# clearer ArgumentError; otherwise rethrow the original exception.
# This try/catch is what Zygote refuses to compile ("try/catch is not supported").
if !isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv)
throw(ArgumentError("Dict(kv): kv needs to be an iterator of tuples or pairs"))
else
rethrow()
end
end
end
(I've given up on `remake`.)
After following Chris's advice here (passing the parameters directly to `solve` and using `wrap = Val(false)`), Zygote and ReverseDiff now fail with the same error, here: https://github.com/SciML/RecursiveArrayTools.jl/blob/0965fc1f69424b9623f1150221d64889185189a3/src/vector_of_array.jl#L113
ERROR: ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved
I can't figure out how to add a breakpoint inside the gradient calculation; any ideas would be appreciated.
Gradient errors are really hard to debug. Maybe @ChrisRackauckas knows a strategy?
I've abandoned the previous approach (I was extending too many Base/stdlib functions that were not meant to be extended). Instead, I'm comparing against a working ModelingToolkit example to see where things go wrong. This Lorenz system example works fine; I forced `FullSpecialize` so that both problems are specialized the same way. The types of the two ODEProblems are:
MethodOfLines:
ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#549"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x558fcb15, 0xbc5f765b, 0x5dbef508, 0xc5018616, 0x76ff7ee2), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#630#generated_observed#559"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}
ModelingToolkit:
ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#549"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xdf8f3253, 0x53462d65, 0x6e40626a, 0xb391e8a4, 0x9f84fe2a), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xfad0093e, 0xafa14235, 0x0e0713c0, 0x13e5b2a3, 0x1da971c5), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#___jac#555"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xc7653a39, 0xff0cdb6a, 0x7e27b429, 0x0d9163ea, 0x97ee01d3), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xd983ef08, 0x57e02881, 0x3cbc6046, 0x1441c5e4, 0xfda2cffd), Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#630#generated_observed#559"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}
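The problem appears to originate somewhere in MethodOfLines and then propagates, producing a confusing stack trace elsewhere in the ecosystem. The most visible difference between the two types is the last type parameter: `MethodOfLines.MOLMetadata{...}` for the MethodOfLines problem versus `SciMLBase.StandardODEProblem` for the ModelingToolkit one. The ModelingToolkit problem also carries a generated Jacobian, whereas the MethodOfLines one has `Nothing` in that slot.
Current closest-to-working MWE: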
using ModelingToolkit, MethodOfLines, DomainSets, OrdinaryDiffEq
using PDEBase: add_metadata!
using ModelingToolkit: get_metadata
using Zygote
using ReverseDiff
import AbstractDifferentiation as AD
using SciMLSensitivity
# Method of Manufactured Solutions: exact solution
# NOTE(review): `u_exact` is not referenced anywhere else in this MWE.
u_exact = (x,t) -> exp.(-t) * cos.(x)
# Parameters, variables, and derivatives
@parameters x t
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2
# 1D PDE and boundary conditions
eq = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
# Initial condition at t = 0, then Dirichlet conditions at x = 0 and x = 1.
bcs = [u(0, x) ~ cos(x),
u(t, 0) ~ exp(-t),
u(t, 1) ~ exp(-t) * cos(1)]
# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
x ∈ Interval(0.0, 1.0)]
# Parameters
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)
# Method of lines discretization
dx = 0.1
order = 2  # NOTE(review): not passed to MOLFiniteDifference, so it has no effect here
discretization = MOLFiniteDifference([x => dx], t)
# Convert the PDE problem into an ODE problem.
# This is done step by step instead of via `discretize`, so the ODEProblem can be
# forced to use `SciMLBase.FullSpecialize`.
sys,tspan = symbolic_discretize(pdesys,discretization)
simpsys = structural_simplify(sys)
add_metadata!(get_metadata(simpsys), sys)
prob= ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan)
# Permutation taking the user's raw vector [α, β] into the order in which the
# simplified system stores its parameters (the order `solve(prob; p = ...)` expects).
# `varmap_to_vars(map, order)` returns the mapped values in the order of its second
# argument, so that argument must be the system's own parameter list. Passing
# `param_vars` there (the same list the map was built from) always yields the
# identity [1, 2], which silently skips the reordering.
param_vars = [α, β]
idxs = ModelingToolkit.varmap_to_vars([param_vars[1] => 1, param_vars[2] => 2],
    ModelingToolkit.parameters(simpsys))
# Sanity check: reorder a sample parameter vector with `idxs`.
test_p = [1.2, 1.4]
test_p[Int.(idxs)]
# Scalar objective for the gradient test.
# The raw vector is indexed with the global `idxs`, which is intended to reorder it into
# the system's parameter order. That vector is passed straight to `solve` as `p`
# (no `remake`), with `wrap = Val(false)`, and the first entry of `sol.u` is summed.
function pde_solution2(ps)
    p_sys = ps[Int.(idxs)]
    solution = solve(prob, Tsit5(); saveat = 0.1, p = p_sys, wrap = Val(false))
    first_state = solution.u[1]
    return sum(first_state)
end
# using Zygote
# Plain forward evaluation (no AD) as a sanity check.
pde_solution2([1.2,.3])
ADzyg = AD.ZygoteBackend()
# Gradient of `pde_solution2` w.r.t. its parameter vector, via the Zygote backend `ADzyg`.
grad(ps) = AD.gradient(ADzyg, pde_solution2, ps)
# Evaluate the gradient at a random 2-element parameter vector.
grad(rand(2))
This crashes at `sum(sol.u[1])`, so the solve itself actually completes. Note that with ReverseDiff it crashes inside the solver instead. From my many attempts, it seems that `sol.u` is somehow a dictionary of matrices, and calling `getindex` on it is what crashes. Interpolation doesn't work either (even without AD).
Infiltrator.jl isn't much help either, because Zygote tries to compile through its macro and crashes.
The `sol` interface is different for PDEs; if you don't care about the shape, pass `wrap = Val(false)` to `solve` and this will work.
I'm already passing `wrap = Val(false)`; without it, it crashes even earlier.
What is the type of sol?
ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#545"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x8704bc7a, 0xbf573882, 0x55a791fb, 0xb2445ff7, 0x35b3dd5b), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#637#generated_observed#555"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#545"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x8704bc7a, 0xbf573882, 0x55a791fb, 0xb2445ff7, 
0x35b3dd5b), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#637#generated_observed#555"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.Stats, Nothing}