MethodOfLines.jl

Zygote not working for gradient w.r.t. parameters (with remake)

Open Qfl3x opened this issue 2 years ago • 15 comments

MWE:

using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets

using Zygote
import AbstractDifferentiation as AD
# Method of Manufactured Solutions: exact solution exp(-t) * cos(x).
# Fully broadcast with `@.` so it accepts scalars as well as equally sized arrays of x and t
# (the partially dotted form `exp.(-t) * cos.(x)` throws for two vectors).
u_exact = (x, t) -> @. exp(-t) * cos(x)

# Parameters, variables, and derivatives
@parameters t x
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D PDE and boundary conditions
eq  = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
bcs = [u(0, x) ~ cos(x),
        u(t, 0) ~ exp(-t),
        u(t, 1) ~ exp(-t) * cos(1)]

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]

# Parameters
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)

# Method of lines discretization
dx = 0.1     # step size used for x
order = 2    # approximation order, forwarded below instead of being left unused
discretization = MOLFiniteDifference([x => dx], t; approx_order=order)

# Convert the PDE problem into an ODE problem
prob = discretize(pdesys,discretization)

# Objective differentiated in this reproducer: pair the two numeric entries of `ps` with the
# symbolic parameters α and β, rebuild `prob` with `remake`, solve with Tsit5 (saveat=0.1),
# and sum the `[end, :]` slice of the u(t, x) solution.
function pde_solution(ps)
    param_map = [α => ps[1], β => ps[2]]
    remade = remake(prob, p=param_map)
    sol = solve(remade, Tsit5(), saveat=0.1)
    return sum(sol[u(t,x)][end,:])
end

ADZyg = AD.ZygoteBackend()
grad = AD.gradient(ADZyg, pde_solution, rand(2))

I used AbstractDifferentiation above, but Zygote alone gives the same error:

ERROR: MethodError: no method matching size(::IRTools.Inner.Undefined)

Closest candidates are:
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted})
   @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:581
  size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer)
   @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:580
  size(::Union{LinearAlgebra.QRCompactWYQ, LinearAlgebra.QRPackedQ})
   @ LinearAlgebra ~/julia-1.9.0/share/julia/stdlib/v1.9/LinearAlgebra/src/qr.jl:584
  ...

Stacktrace:
  [1] axes(A::IRTools.Inner.Undefined)
    @ Base ./abstractarray.jl:98
  [2] _tryaxes(x::IRTools.Inner.Undefined)
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/array.jl:188
  [3] map
    @ ./tuple.jl:274 [inlined]
  [4] adjoint
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/lib/array.jl:322 [inlined]
  [5] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:66 [inlined]
  [6] _pullback
    @ ./iterators.jl:370 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(zip), ::IRTools.Inner.Undefined, ::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [8] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/ModelingToolkit/dkLCE/src/utils.jl:635 [inlined]
  [9] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Float64}, ::IRTools.Inner.Undefined)
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [10] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/ModelingToolkit/dkLCE/src/variables.jl:149 [inlined]
 [11] _pullback(::Zygote.Context{false}, ::typeof(SciMLBase.process_p_u0_symbolic), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, ::Vector{Pair{Num, Float64}}, ::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [12] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/remake.jl:78 [inlined]
 [13] _pullback(::Zygote.Context{false}, ::SciMLBase.var"##remake#617", ::Missing, ::Missing, ::Missing, ::Vector{Pair{Num, Float64}}, ::Missing, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [14] _pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/remake.jl:52 [inlined]
 [15] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:p,), Tuple{Vector{Pair{Num, Float64}}}}, ::typeof(remake), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#540"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x3eb0599d, 0xa6fca765, 0x8abae924, 0x13fbef1d, 0xea3dc025), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x7fa78367, 0x27aca4b8, 0x1bc76d77, 0x077e27ff, 0x63bf4ad3), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [16] _pullback
    @ ~/GitProjects/hyco/scripts/mwe.jl:40 [inlined]
 [17] _pullback(ctx::Zygote.Context{false}, f::typeof(pde_solution), args::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [18] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:44
 [19] pullback
    @ /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:42 [inlined]
 [20] gradient(f::Function, args::Vector{Float64})
    @ Zygote /Net/Groups/BGI/people/mchettouh/.julia/packages/Zygote/4SSHS/src/compiler/interface.jl:96
 [21] top-level scope
    @ REPL[7]:1

Qfl3x avatar Sep 14 '23 13:09 Qfl3x

So I've modified remake in SciMLBase by essentially removing the symbolic checks here: https://github.com/SciML/SciMLBase.jl/blob/75605b1a8754bc4452761100589a6020b8e9f035/src/remake.jl#L71-L79

Once that is done, I get a new error:

ERROR: No matching function wrapper was found!

that starts at:

  [7] (::ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#622#generated_observed#548"{Bool, ODESystem, Dict{Any, Any}}, Nothing, ODESystem})(::Vector{Float64}, ::Vararg{Any})
    @ SciMLBase /Net/Groups/BGI/people/mchettouh/.julia/packages/SciMLBase/jNK7c/src/scimlfunctions.jl:2267

Qfl3x avatar Sep 27 '23 09:09 Qfl3x

FunctionWrappersWrappers's first non-_call stacktrace:

  [6] (::FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Pair{Num, Float64}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false})(::Vector{Float64}, ::Vector{Float64}, ::Vector{Float64}, ::Float64)
    @ FunctionWrappersWrappers /Net/Groups/BGI/people/mchettouh/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10

I didn't put the entire stacktrace as it's very long.

Qfl3x avatar Sep 27 '23 10:09 Qfl3x

This is linked to the following issue: https://github.com/SciML/SciMLSensitivity.jl/issues/794

BernhardAhrens avatar Sep 29 '23 09:09 BernhardAhrens

Could a workaround be to avoid wrapping altogether by passing the wrap = Val(false) keyword argument to solve?

https://docs.sciml.ai/MethodOfLines/stable/solutions/#Original-solution

BernhardAhrens avatar Sep 29 '23 09:09 BernhardAhrens

@BernhardAhrens It's currently crashing at remake; it hasn't reached solve yet.

Qfl3x avatar Sep 29 '23 10:09 Qfl3x

I've tried to pinpoint the error more clearly and got this new MWE:

using DifferentialEquations, ModelingToolkit, MethodOfLines, DomainSets

using PDEBase: add_metadata!
using ModelingToolkit: get_metadata

using Zygote
import AbstractDifferentiation as AD
# Method of Manufactured Solutions: exact solution exp(-t) * cos(x).
# Fully broadcast with `@.` so it accepts scalars as well as equally sized arrays of x and t
# (the partially dotted form `exp.(-t) * cos.(x)` throws for two vectors).
u_exact = (x, t) -> @. exp(-t) * cos(x)

# Parameters, variables, and derivatives
@parameters x t
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D PDE and boundary conditions
eq  = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
bcs = [u(0, x) ~ cos(x),
        u(t, 0) ~ exp(-t),
        u(t, 1) ~ exp(-t) * cos(1)]

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]

# Parameters
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)

# Method of lines discretization
dx = 0.1     # step size used for x
order = 2    # approximation order, forwarded below instead of being left unused
discretization = MOLFiniteDifference([x => dx], t; approx_order=order)

# Convert the PDE problem into an ODE problem
sys,tspan = symbolic_discretize(pdesys,discretization)
simpsys = structural_simplify(sys)
add_metadata!(get_metadata(simpsys), sys)
prob = ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan; discretization.kwargs...)


"""
    remake_p(prob::ODEProblem, p; simpsys=simpsys)

Simplified replacement for `remake(prob, p=p)` in this reproducer: construct a new
`ODEProblem{true, SciMLBase.FullSpecialize}` from `simpsys`, reusing `prob.u0` and
`prob.tspan` and taking `p` as the parameters.

Nothing else is carried over from `prob`.
"""
function remake_p(prob::ODEProblem, p; simpsys=simpsys)
    return ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, prob.u0, prob.tspan, p)
end

# Objective differentiated in this reproducer: map the two numeric entries of `ps` onto the
# symbolic parameters α and β, rebuild the problem through `remake_p`, solve with Tsit5
# (saveat=0.1), and sum the `[end, :]` slice of the u(t, x) solution.
function pde_solution2(ps)
    param_map = [α => ps[1], β => ps[2]]
    rebuilt = remake_p(prob, param_map)
    sol = solve(rebuilt, Tsit5(), saveat=0.1)
    final_slice = sol[u(t,x)][end,:]
    return sum(final_slice)
end


pde_solution2([1.2,.3]);
ADzyg = AD.ZygoteBackend()
AD.gradient(ADzyg, pde_solution2, rand(2))

This one avoids some of the Zygote errors (it forces FullSpecialize and uses a simplified remake), but it still crashes with:



ERROR: Compiling Tuple{Type{Dict}, Vector{Pair{Num, Float64}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations

Stacktrace:
  [1] macro expansion
    @ home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::Type{Dict}, args::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:101
  [3] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/utils.jl:633 [inlined]
  [4] _pullback(::Zygote.Context{false}, ::typeof(ModelingToolkit.mergedefaults), ::Dict{Any, Any}, ::Vector{Pair{Num, Float64}}, ::IRTools.Inner.Undefined)
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [5] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:694 [inlined]
  [6] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##get_u0_p#578", ::Bool, ::Bool, ::Bool, ::typeof(ModelingToolkit.get_u0_p), ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [7] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:684 [inlined]
  [8] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:tofloat, :use_union, :symbolic_u0), Tuple{Bool, Bool, Bool}}, ::typeof(ModelingToolkit.get_u0_p), ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
  [9] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:724 [inlined]
 [10] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##process_DEProblem#579", ::Bool, ::Nothing, ::Nothing, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Bool, ::Symbolics.SerialForm, ::Bool, ::Bool, ::Bool, ::Bool, ::Base.Pairs{Symbol, Real, Tuple{Symbol, Symbol, Symbol}, NamedTuple{(:t, :has_difference, :check_length), Tuple{Float64, Bool, Bool}}}, ::typeof(ModelingToolkit.process_DEProblem), ::Type{ODEFunction{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [11] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:707 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::NamedTuple{(:t, :has_difference, :check_length), Tuple{Float64, Bool, Bool}}, ::typeof(ModelingToolkit.process_DEProblem), ::Type{ODEFunction{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [13] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:834 [inlined]
 [14] _pullback(::Zygote.Context{false}, ::ModelingToolkit.var"##_#586", ::Nothing, ::Bool, ::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::Type{ODEProblem{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Tuple{Float64, Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
 [15] _pullback
    @ home/.julia/packages/ModelingToolkit/dkLCE/src/systems/diffeqs/abstractodesystem.jl:827 [inlined]
 [16] _pullback(::Zygote.Context{false}, ::Type{ODEProblem{true, SciMLBase.FullSpecialize}}, ::ODESystem, ::Vector{Float64}, ::Tuple{Float64, Float64}, ::Vector{Pair{Num, Float64}})
    @ Zygote home/.julia/packages/Zygote/4SSHS/src/compiler/interface2.jl:0
.... 

The crash is happening in the stdlib here: https://github.com/JuliaLang/julia/blob/b74daf501619ac4be061c67d80608c4c8822fc36/base/dict.jl#L114-L126



# Quoted from Julia Base (base/dict.jl): construct a `Dict` from `kv`.
# The try/catch below is the construct Zygote reports as unsupported in the error above.
function Dict(kv)
    try
        # Build the dictionary through `dict_with_eltype`, passing the element type of `kv`.
        dict_with_eltype((K, V) -> Dict{K, V}, kv, eltype(kv))
    catch
        # If `kv` is not iterable, or not every element is a Tuple or Pair, raise a
        # descriptive ArgumentError; otherwise re-raise the original exception.
        if !isiterable(typeof(kv)) || !all(x->isa(x,Union{Tuple,Pair}),kv)
            throw(ArgumentError("Dict(kv): kv needs to be an iterator of tuples or pairs"))
        else
            rethrow()
        end
    end
end

Qfl3x avatar Sep 29 '23 10:09 Qfl3x

(Given up on remake)

After following Chris's advice here, calling solve with parameters, and using wrap=Val(false), I've reached error-parity between Zygote and ReverseDiff here: https://github.com/SciML/RecursiveArrayTools.jl/blob/0965fc1f69424b9623f1150221d64889185189a3/src/vector_of_array.jl#L113

ERROR: ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved

Qfl3x avatar Oct 04 '23 08:10 Qfl3x

I can't figure out a way to add a breakpoint inside the gradient calculation; it would be nice if someone had an idea.

Qfl3x avatar Oct 04 '23 08:10 Qfl3x

Gradient errors are really hard to debug, maybe @ChrisRackauckas knows a strategy?

xtalax avatar Oct 04 '23 14:10 xtalax

I've abandoned the previous approach (I was extending too many stdlib functions that weren't supposed to be extended) and am instead looking for a working example in ModelingToolkit to see where things go wrong. This Lorenz system example works fine (I forced FullSpecialize to make sure both problems are specialized the same way). The types of the two ODEProblems are:

MethodOfLines:

ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#549"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x558fcb15, 0xbc5f765b, 0x5dbef508, 0xc5018616, 0x76ff7ee2), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#630#generated_observed#559"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}

ModelingToolkit:

ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#549"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xdf8f3253, 0x53462d65, 0x6e40626a, 0xb391e8a4, 0x9f84fe2a), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xfad0093e, 0xafa14235, 0x0e0713c0, 0x13e5b2a3, 0x1da971c5), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, ModelingToolkit.var"#___jac#555"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xc7653a39, 0xff0cdb6a, 0x7e27b429, 0x0d9163ea, 0x97ee01d3), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0xd983ef08, 0x57e02881, 0x3cbc6046, 0x1441c5e4, 0xfda2cffd), Nothing}}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#630#generated_observed#559"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}

The problem appears to originate from MethodOfLines somehow, then propagates to give a weird stacktrace elsewhere in the ecosystem.

Qfl3x avatar Oct 10 '23 07:10 Qfl3x

Current closest-to-working MWE:

using ModelingToolkit, MethodOfLines, DomainSets, OrdinaryDiffEq

using PDEBase: add_metadata!
using ModelingToolkit: get_metadata

using Zygote
using ReverseDiff
import AbstractDifferentiation as AD
using SciMLSensitivity
# Method of Manufactured Solutions: exact solution
u_exact = (x,t) -> exp.(-t) * cos.(x)

# Parameters, variables, and derivatives
@parameters x t
@variables u(..)
@parameters α β
Dt = Differential(t)
Dxx = Differential(x)^2

# 1D PDE and boundary conditions
eq  = Dt(u(t, x)) ~(α + β) * Dxx(u(t, x))
bcs = [u(0, x) ~ cos(x),
        u(t, 0) ~ exp(-t),
        u(t, 1) ~ exp(-t) * cos(1)]

# Space and time domains
domains = [t ∈ Interval(0.0, 1.0),
           x ∈ Interval(0.0, 1.0)]

# Parameters
ps = [α => 1.2, β => 2.1]
# PDE system
@named pdesys = PDESystem(eq, bcs, domains, [t, x], [u(t, x)], ps)

# Method of lines discretization
dx = 0.1     # step size used for x
order = 2    # approximation order, forwarded below instead of being left unused
discretization = MOLFiniteDifference([x => dx], t; approx_order=order)

# Convert the PDE problem into an ODE problem
sys,tspan = symbolic_discretize(pdesys,discretization)
simpsys = structural_simplify(sys)
add_metadata!(get_metadata(simpsys), sys)
prob= ODEProblem{true, SciMLBase.FullSpecialize}(simpsys, Pair[], tspan)
param_vars = [α, β]
idxs = ModelingToolkit.varmap_to_vars([param_vars[1] => 1, param_vars[2] => 2], param_vars)

test_p = [1.2, 1.4]
test_p[Int.(idxs)]
# Objective differentiated in this reproducer: reorder the numeric parameters with `idxs`,
# pass them straight to `solve` via `p` (no `remake`), and sum one entry of `sol.u`.
function pde_solution2(ps)
    ordered_ps = ps[Int.(idxs)]
    sol = solve(prob, Tsit5(), saveat=0.1, p=ordered_ps, wrap=Val(false))
    # NOTE(review): `sol.u[1]` is the first saved entry; confirm this is intended rather
    # than the last one (`sol.u[end]`).
    first_state = sol.u[1]
    return sum(first_state)
end

# using Zygote
pde_solution2([1.2,.3])
ADzyg = AD.ZygoteBackend()
# Gradient of `pde_solution2` at `ps`, computed through the AbstractDifferentiation
# Zygote backend `ADzyg`.
grad(ps) = AD.gradient(ADzyg, pde_solution2, ps)
grad(rand(2))

This crashes at sum(sol.u[1]), so the solve itself actually completes. Note that with ReverseDiff it crashes inside the solver instead. From my many attempts, it seems that sol.u is somehow a dictionary of matrices, and calling getindex on it is what crashes. Interpolation doesn't work either (even without AD).

Infiltrator isn't helping me much as Zygote is trying to compile the macro and crashes.

Qfl3x avatar Oct 13 '23 08:10 Qfl3x

The sol interface is different for PDEs. If you don't care about the shape, pass wrap = Val(false) to solve and this will work.

xtalax avatar Oct 16 '23 13:10 xtalax

I'm already passing wrap=Val(false); without it, it crashes sooner.

Qfl3x avatar Oct 16 '23 13:10 Qfl3x

What is the type of sol?

xtalax avatar Oct 16 '23 13:10 xtalax

ODESolution{Float64, 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, Vector{Float64}, ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#545"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x8704bc7a, 0xbf573882, 0x55a791fb, 0xb2445ff7, 0x35b3dd5b), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#637#generated_observed#555"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, MethodOfLines.MOLMetadata{Val{true}(), MethodOfLines.DiscreteSpace{1, 1, MethodOfLines.CenterAlignedGrid}, MOLFiniteDifference{MethodOfLines.CenterAlignedGrid, MethodOfLines.ScalarizedDiscretization}, PDESystem, Base.RefValue{Any}, MethodOfLines.ScalarizedDiscretization}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{ODEFunction{true, SciMLBase.FullSpecialize, ModelingToolkit.var"#k#545"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x85c1aa6d, 0xc54674be, 0xc3a0c1ac, 0x6b84bad6, 0xd1e187c5), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x8704bc7a, 0xbf573882, 0x55a791fb, 0xb2445ff7, 
0x35b3dd5b), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Vector{Symbol}, Symbol, Vector{Symbol}, ModelingToolkit.var"#637#generated_observed#555"{Bool, ODESystem, Dict{Any, Any}, Vector{SymbolicUtils.BasicSymbolic{Real}}}, Nothing, ODESystem}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiffEq.Tsit5Cache{Vector{Float64}, Vector{Float64}, Vector{Float64}, typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}}, DiffEqBase.Stats, Nothing}

Qfl3x avatar Oct 16 '23 13:10 Qfl3x