DiffEqFlux.jl
DiffEqFlux.jl copied to clipboard
Ensemble backpropagation error
I'm not certain if this is a DiffEqFlux error or a Zygote error, considering it triggers depending on some seemingly harmless changes to the source code. MWE:
julia> using DiffEqFlux, DifferentialEquations
using DiffEqFlux: Zygote
model = Chain(
Dense(3, 3),
)
p, re = Flux.destructure(model)
solver = Tsit5()
tspan = (0.0, 1e-5)
prob_node = ODEProblem((u,p,t)->re(p)(u), [1.0, 0.0, 0.0], tspan, p)
function problem_generator(progenitor, idx, args...)
return remake(
progenitor,
p=p,
)
end
Zygote.gradient(Params([p])) do
sols = solve(
EnsembleProblem(prob_node, prob_func=problem_generator),
Tsit5(),
EnsembleSerial();
trajectories=1,
)
group_predictions = [Array(s) for s in sols]
loss = sum(sum(abs2.(group_predictions[i]) for i in 1:length(sols)))
@show loss
return loss
end
This results in:
loss = 1.9999973954799817
ERROR: MethodError: no method matching +(::Matrix{Float64}, ::Float64)
For element-wise addition, use broadcasting with dot syntax: array .+ scalar
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
+(::Union{VectorizationBase.AbstractMask{W, U} where U<:Union{UInt128, UInt16, UInt32, UInt64, UInt8}, VectorizationBase.VecUnroll{var"#s3", W, SIMDTypes.Bit, var"#s2"} where {var"#s3", var"#s2"<:(VectorizationBase.AbstractMask{W, U} where U<:Union{UInt128, UInt16, UInt32, UInt64, UInt8})}} where W, ::Union{Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, SIMDTypes.Bit}) at /home/sabae/.julia/packages/VectorizationBase/OEl8L/src/base_defs.jl:285
+(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at /home/sabae/.julia/packages/InitialValues/P5PLf/src/InitialValues.jl:153
...
Stacktrace:
[1] generic_matvecmul!(C::Vector{Any}, tA::Char, A::Matrix{Float64}, B::Vector{Any}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:747
[2] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:129 [inlined]
[3] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
[4] *(adjA::LinearAlgebra.Adjoint{Float64, Matrix{Float64}}, x::Vector{Any})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:108
[5] (::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{T} where T}, Vector{Float64}})(i::Int64)
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/5D9Rg/src/concrete_solve.jl:533
[6] _mapreduce(f::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{T} where T}, Vector{Float64}}, op::typeof(Base.add_sum), #unused#::IndexLinear, A::Base.OneTo{Int64})
@ Base ./reduce.jl:408
[7] _mapreduce_dim(f::Function, op::Function, #unused#::Base._InitialValue, A::Base.OneTo{Int64}, #unused#::Colon)
@ Base ./reducedim.jl:318
[8] #mapreduce#672
@ ./reducedim.jl:310 [inlined]
[9] mapreduce
@ ./reducedim.jl:310 [inlined]
[10] #_sum#682
@ ./reducedim.jl:878 [inlined]
[11] _sum
@ ./reducedim.jl:878 [inlined]
[12] #sum#680
@ ./reducedim.jl:874 [inlined]
[13] sum(f::Function, a::Base.OneTo{Int64})
@ Base ./reducedim.jl:874
[14] (::DiffEqSensitivity.var"#252#258"{0, Vector{Vector{T} where T}, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}})()
@ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/5D9Rg/src/concrete_solve.jl:524
[15] unthunk
@ ~/.julia/packages/ChainRulesCore/7ZiwT/src/tangent_types/thunks.jl:194 [inlined]
[16] wrap_chainrules_output
@ ~/.julia/packages/Zygote/bJn8I/src/compiler/chainrules.jl:104 [inlined]
[17] map(f::typeof(Zygote.wrap_chainrules_output), t::Tuple{ChainRulesCore.Thunk{DiffEqSensitivity.var"#252#258"{0, Vector{Vector{T} where T}, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}, ChainRulesCore.Thunk{DiffEqSensitivity.var"#250#256"{0, Vector{Vector{T} where T}, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}, ChainRulesCore.NoTangent})
@ Base ./tuple.jl:215
[18] map (repeats 3 times)
@ ./tuple.jl:216 [inlined]
[19] wrap_chainrules_output
@ ~/.julia/packages/Zygote/bJn8I/src/compiler/chainrules.jl:105 [inlined]
[20] ZBack
@ ~/.julia/packages/Zygote/bJn8I/src/compiler/chainrules.jl:204 [inlined]
[21] (::Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DiffEqSensitivity.var"#forward_sensitivity_backpass#255"{0, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}})(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203
[22] (::Zygote.var"#1733#back#210"{Zygote.var"#208#209"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, Zygote.ZBack{DiffEqSensitivity.var"#forward_sensitivity_backpass#255"{0, Vector{Float64}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float32}, ODEFunction{false, var"#65#66", LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, ForwardDiffSensitivity{0, nothing}, Vector{Float64}, Vector{Float32}, Tuple{}, UnitRange{Int64}}}}})(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[23] Pullback
@ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:73 [inlined]
[24] (::typeof(∂(#solve#43)))(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[25] (::Zygote.var"#208#209"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#43))})(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203
[26] #1733#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[27] Pullback
@ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:68 [inlined]
[28] (::typeof(∂(solve)))(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[29] Pullback
@ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:143 [inlined]
[30] (::typeof(∂(#batch_func#460)))(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[31] Pullback
@ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:139 [inlined]
[32] (::typeof(∂(batch_func)))(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[33] Pullback
@ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:195 [inlined]
[34] (::typeof(∂(λ)))(Δ::Vector{Vector{Any}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[35] (::DiffEqBase.var"#186#194")(f::typeof(∂(λ)), δ::Vector{Vector{Any}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/b1nST/src/init.jl:128
[36] responsible_map(::Function, ::Vector{typeof(∂(λ))}, ::Vararg{Any, N} where N)
@ SciMLBase ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:188
[37] (::DiffEqBase.var"#∇responsible_map_internal#193"{Vector{typeof(∂(λ))}})(Δ::EnsembleSolution{Vector{Any}, 2, Vector{Vector{Vector{Any}}}})
@ DiffEqBase ~/.julia/packages/DiffEqBase/b1nST/src/init.jl:128
[38] #168#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[39] Pullback
@ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:194 [inlined]
[40] (::typeof(∂(#solve_batch#464)))(Δ::EnsembleSolution{Vector{Any}, 2, Vector{Vector{Vector{Any}}}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
--- the last 2 lines are repeated 1 more time ---
[43] macro expansion
@ ./timing.jl:287 [inlined]
[44] Pullback
@ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
[45] (::typeof(∂(#__solve#459)))(Δ::Array{Any, 3})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[46] Pullback
@ ~/.julia/packages/SciMLBase/7GnZA/src/ensemble/basic_ensemble_solve.jl:103 [inlined]
[47] (::typeof(∂(__solve##kw)))(Δ::Array{Any, 3})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[48] #208
@ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203 [inlined]
[49] (::Zygote.var"#1733#back#210"{Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof(∂(__solve##kw))}})(Δ::Array{Any, 3})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[50] Pullback
@ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:101 [inlined]
[51] (::typeof(∂(#solve#45)))(Δ::Array{Any, 3})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[52] (::Zygote.var"#208#209"{Tuple{Tuple{Nothing, Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof(∂(#solve#45))})(Δ::Array{Any, 3})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:203
[53] #1733#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[54] Pullback
@ ~/.julia/packages/DiffEqBase/b1nST/src/solve.jl:98 [inlined]
[55] (::typeof(∂(solve##kw)))(Δ::Array{Any, 3})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[56] Pullback
@ ./REPL[19]:20 [inlined]
[57] (::typeof(∂(#67)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[58] (::Zygote.var"#89#90"{Params, typeof(∂(#67)), Zygote.Context})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:356
[59] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:76
[60] top-level scope
@ REPL[19]:19
Which, as far as I can tell, is due to the fact that v is a matrix and not a vector, as we would expect, right here: https://github.com/SciML/DiffEqSensitivity.jl/blob/b46fab37b6f1c3b9eda8a407ae60d6f87ea261b9/src/concrete_solve.jl#L302
The crazy thing is that if we change:
group_predictions = [Array(s) for s in sols]
loss = sum(sum(abs2.(group_predictions[i]) for i in 1:length(sols)))
to just:
loss = sum(sum(abs2.(Array(sols[i])) for i in 1:length(sols)))
Then it all works. This makes me think it might be a Zygote bug, but I figured I'd ask here first.
Environment
(jl_bmACXe) pkg> st
Status `/tmp/jl_bmACXe/Project.toml`
[aae7a2af] DiffEqFlux v1.44.0
[0c46a032] DifferentialEquations v6.20.0
Which, as far as I can tell, is due to the fact that v is a matrix and not a vector, as we would expect, right here:
Yeah, I would expect a vector of vectors to pull back as a vector of vectors, but Zygote must be adding some projection somewhere?
@DhairyaLGandhi
Yeah, preserving the input data structure was one of the reasons I think projection can be surprising.
Can you help us narrow down what changed to cause this projection?
Sure, I'll take a look
@DhairyaLGandhi did this get fixed?
Tests are passing on this these days.