DiffEqFlux.jl icon indicating copy to clipboard operation
DiffEqFlux.jl copied to clipboard

Ensemble backpropagation error

Open staticfloat opened this issue 4 years ago • 5 comments

I'm not certain if this is a DiffEqFlux error or a Zygote error, considering it triggers depending on some seemingly harmless changes to the source code. MWE:

julia> using DiffEqFlux, DifferentialEquations
       using DiffEqFlux: Zygote

       # Reproduction setup: a Chain holding a single Dense(3, 3) layer, used below
       # as the right-hand side of the ODE.
       model = Chain(
           Dense(3, 3),
       )
       # Split the model into a parameter vector `p` and a function `re`; the ODE
       # function below calls `re(p)` and applies the result to the state `u`.
       p, re = Flux.destructure(model)
       # NOTE(review): `solver` is not referenced again in this snippet; the `solve`
       # call further down constructs its own `Tsit5()`.
       solver = Tsit5()
       tspan = (0.0, 1e-5)
       # ODE problem with initial state [1.0, 0.0, 0.0], time span `tspan`, and
       # parameters `p`.
       prob_node = ODEProblem((u,p,t)->re(p)(u), [1.0, 0.0, 0.0], tspan, p)

       # Callback passed as `prob_func` to the EnsembleProblem in this snippet.
       # It ignores `idx` and any extra arguments and returns
       # `remake(progenitor, p=p)`, where `p` is the global parameter vector from
       # `Flux.destructure`, so every trajectory is given the same `p`.
       function problem_generator(progenitor, idx, args...)
           return remake(
               progenitor,
               p=p,
          )
       end
           
       # Differentiate the loss with respect to `p` via `Zygote.gradient` with
       # `Params([p])`. According to the report, `loss` is printed successfully
       # and the MethodError is raised afterwards, inside the pullback.
       Zygote.gradient(Params([p])) do
           # Solve the ensemble with `EnsembleSerial()` and a single trajectory.
           sols = solve(
               EnsembleProblem(prob_node, prob_func=problem_generator),
               Tsit5(),
               EnsembleSerial();
               trajectories=1,
           )
           # Materialising the arrays into `group_predictions` first and then
           # indexing them is the form that fails; per the report, writing
           # `Array(sols[i])` inline inside the sum works.
           group_predictions = [Array(s) for s in sols]
           loss = sum(sum(abs2.(group_predictions[i]) for i in 1:length(sols)))
           @show loss
           return loss
       end

This results in:

loss = 1.9999973954799817
ERROR: MethodError: no method matching +(::Matrix{Float64}, ::Float64)
For element-wise addition, use broadcasting with dot syntax: array .+ scalar
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
  +(::Union{VectorizationBase.AbstractMask{W, U} where U<:Union{UInt128, UInt16, UInt32, UInt64, UInt8}, VectorizationBase.VecUnroll{var"#s3", W, SIMDTypes.Bit, var"#s2"} where {var"#s3", var"#s2"<:(VectorizationBase.AbstractMask{W, U} where U<:Union{UInt128, UInt16, UInt32, UInt64, UInt8})}} where W, ::Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit}) at /home/sabae/.julia/packages/VectorizationBase/OEl8L/src/base_defs.jl:285
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at /home/sabae/.julia/packages/InitialValues/P5PLf/src/InitialValues.jl:153
  ...
Stacktrace:
  [1] generic_matvecmul!(C::Vector{Any}, tA::Char, A::Matrix{Float64}, B::Vector{Any}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:747
  [2] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:129 [inlined]
  [3] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
  [4] *(adjA::LinearAlgebra.Adjoint{Float64, Matrix{Float64}}, x::Vector{Any})
    @ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:108
  [5] (::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{T} where T}, Vector{Float64}})(i::Int64)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/5D9Rg/src/concrete_solve.jl:533
  [6] _mapreduce(f::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{T} where T}, Vector{Float64}}, op::typeof(Base.add_sum), #unused#::IndexLinear, A::Base.OneTo{Int64})
    @ Base ./reduce.jl:408
  [7] _mapreduce_dim(f::Function, op::Function, #unused#::Base._InitialValue, A::Base.OneTo{Int64}, #unused#::Colon)
    @ Base ./reducedim.jl:318
  [8] #mapreduce#672
    @ ./reducedim.jl:310 [inlined]
  [9] mapreduce
    @ ./reducedim.jl:310 [inlined]
 [10] #_sum#682
    @ ./reducedim.jl:878 [inlined]
 [11] _sum
    @ ./reducedim.jl:878 [inlined]
 [12] #sum#680
    @ ./reducedim.jl:874 [inlined]
 [13] sum(f::Function, a::Base.OneTo{Int64})
    @ Base ./reducedim.jl:874
 [14] (::DiffEqSensitivity.var"#252#258"{0, Vector{Vector{T} where T}, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}})()
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/5D9Rg/src/concrete_solve.jl:524
 [15] unthunk
    @ ~/.julia/packages/ChainRulesCore/7ZiwT/src/tangent_types/thunks.jl:194 [inlined]
 [16] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/chainrules.jl:104 [inlined]
 [17] map(f::typeof(Zygote.wrap_chainrules_output), t::Tuple{ChainRulesCore.Thunk{DiffEqSensitivity.var"#252#258"{0, Vector{Vector{T} where T}, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}, ChainRulesCore.Thunk{DiffEqSensitivity.var"#250#256"{0, Vector{Vector{T} where T}, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}, ChainRulesCore.NoTangent})
    @ Base ./tuple.jl:215
 [18] map (repeats 3 times)
    @ ./tuple.jl:216 [inlined]
 [19] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/chainrules.jl:105 [inlined]
 [20] ZBack
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/chainrules.jl:204 [inlined]
 [21] (::Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DiffEqSensitivity.var"#forward_sensitivity_backpass#255"{0, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}})(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203
 [22] (::Zygote.var"#1733#back#210"{Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DiffEqSensitivity.var"#forward_sensitivity_backpass#255"{0, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}}})(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [23] Pullback
    @ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:73 [inlined]
 [24] (::typeof(∂(#solve#43)))(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#43))})(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203
 [26] #1733#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [27] Pullback
    @ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:68 [inlined]
 [28] (::typeof(∂(solve)))(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [29] Pullback
    @ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:143 [inlined]
 [30] (::typeof(∂(#batch_func#460)))(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [31] Pullback
    @ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:139 [inlined]
 [32] (::typeof(∂(batch_func)))(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [33] Pullback
    @ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:195 [inlined]
 [34] (::typeof(∂(λ)))(Δ::Vector{Vector{Any}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [35] (::DiffEqBase.var"#186#194")(f::typeof(∂(λ)), δ::Vector{Vector{Any}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/b1nST/src/init.jl:128
 [36] responsible_map(::Function, ::Vector{typeof(∂(λ))}, ::Vararg{Any, N} where N)
    @ SciMLBase ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:188
 [37] (::DiffEqBase.var"#∇responsible_map_internal#193"{Vector{typeof(∂(λ))}})(Δ::EnsembleSolution{Vector{Any}, 2, Vector{Vector{Vector{Any}}}})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/b1nST/src/init.jl:128
 [38] #168#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [39] Pullback
    @ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:194 [inlined]
 [40] (::typeof(∂(#solve_batch#464)))(Δ::EnsembleSolution{Vector{Any}, 2, Vector{Vector{Vector{Any}}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
--- the last 2 lines are repeated 1 more time ---
 [43] macro expansion
    @ ./timing.jl:287 [inlined]
 [44] Pullback
    @ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
 [45] (::typeof(∂(#__solve#459)))(Δ::Array{Any, 3})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [46] Pullback
    @ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:103 [inlined]
 [47] (::typeof(∂(__solve##kw)))(Δ::Array{Any, 3})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [48] #208
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203 [inlined]
 [49] (::Zygote.var"#1733#back#210"{Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof(∂(__solve##kw))}})(Δ::Array{Any, 3})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
 [50] Pullback
    @ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:101 [inlined]
 [51] (::typeof(∂(#solve#45)))(Δ::Array{Any, 3})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [52] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof(∂(#solve#45))})(Δ::Array{Any, 3})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203
 [53] #1733#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [54] Pullback
    @ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:98 [inlined]
 [55] (::typeof(∂(solve##kw)))(Δ::Array{Any, 3})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [56] Pullback
    @ ./REPL[19]:20 [inlined]
 [57] (::typeof(∂(#67)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [58] (::Zygote.var"#89#90"{Params, typeof(∂(#67)), Zygote.Context})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:356
 [59] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
 [60] top-level scope
    @ REPL[19]:19

Which, as far as I can tell, is due to the fact that v is a matrix and not a vector, as we would expect, right here: https://github.com/SciML/DiffEqSensitivity.jl/blob/b46fab37b6f1c3b9eda8a407ae60d6f87ea261b9/src/concrete_solve.jl#L302

The crazy thing is that if we change:

group_predictions = [Array(s) for s in sols]
loss = sum(sum(abs2.(group_predictions[i]) for i in 1:length(sols)))

to just:

loss = sum(sum(abs2.(Array(sols[i])) for i in 1:length(sols)))

Then it all works. This makes me think it might be a Zygote bug, but I figured I'd ask here first.

Environment

(jl_bmACXe) pkg> st
      Status `/tmp/jl_bmACXe/Project.toml`
  [aae7a2af] DiffEqFlux v1.44.0
  [0c46a032] DifferentialEquations v6.20.0

staticfloat avatar Dec 02 '21 02:12 staticfloat

> Which, as far as I can tell, is due to the fact that v is a matrix and not a vector, as we would expect, right here:

Yeah, I would expect a vector of vectors to pull back as a vector of vectors, but Zygote must be adding some projection somewhere?

@DhairyaLGandhi

ChrisRackauckas avatar Dec 02 '21 02:12 ChrisRackauckas

Yeah, preserving the input data structure was one of the reasons I think projection can be surprising.

DhairyaLGandhi avatar Dec 02 '21 12:12 DhairyaLGandhi

Can you help us narrow down what changed to cause this projection?

ChrisRackauckas avatar Dec 02 '21 12:12 ChrisRackauckas

Sure, I'll take a look

DhairyaLGandhi avatar Dec 02 '21 16:12 DhairyaLGandhi

@DhairyaLGandhi did this get fixed?

ChrisRackauckas avatar Jun 11 '22 03:06 ChrisRackauckas

Tests for this are passing these days.

ChrisRackauckas avatar Nov 22 '23 16:11 ChrisRackauckas