SciMLSensitivity.jl icon indicating copy to clipboard operation
SciMLSensitivity.jl copied to clipboard

ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 4 and 5")

Open prbzrg opened this issue 3 years ago • 1 comments

Moving the `alg` and `sensealg` keyword arguments from `solve` to the `ODEProblem` constructor causes an error.

code:

using Flux
using ReverseDiff
using OrdinaryDiffEq
using SciMLSensitivity

# Solver and sensitivity algorithm under test; the bug is triggered by where
# these are passed (ODEProblem vs. solve), not by the algorithms themselves.
solvealg_test = Tsit5()
sensealg_test = InterpolatingAdjoint()
tspan = (0.0, 1.0)

# Small neural network: a single 4 -> 4 dense layer with tanh activation.
nn = Chain(
    Dense(4, 4, tanh),
)
# Flatten the network parameters into a vector `p0` (Float32, from Flux)
# and keep `re` to rebuild the network from a parameter vector.
p0, re = Flux.destructure(nn)
# Initial condition: 4 states x 8 trajectories (a batched matrix u0).
u0 = rand(4, 8)
# Out-of-place ODE right-hand side: rebuild the network from `p` and apply it.
f_aug(u, p, t) = re(p)(u)

# Loss used for the reproduction. NOTE(review): this is the failing variant —
# `alg` and `sensealg` are passed as keyword arguments to the ODEProblem
# constructor rather than to `solve`. The forward call (`res1`) works, but the
# ReverseDiff reverse pass errors with a DimensionMismatch (lengths 4 and 5 —
# presumably the extra kwargs get threaded into the adjoint's tangent tuple;
# see the commented-out working variant below, which passes them to `solve`).
function loss(p)
    prob = ODEProblem(f_aug, u0, tspan, p; alg=solvealg_test, sensealg=sensealg_test)
    sol = solve(prob)
    # Sum the final state across all 4 states and 8 batch trajectories.
    sum(sol[:, :, end])
end

#= works fine
function loss(p)
    prob = ODEProblem(f_aug, u0, tspan, p)
    sol = solve(prob, solvealg_test; sensealg=sensealg_test)
    sum(sol[:, :, end])
end
=#

# Forward evaluation succeeds; only the gradient (reverse pass) errors.
res1 = loss(p0)
# ReverseDiff.gradient triggers the DimensionMismatch reported in this issue.
res2 = ReverseDiff.gradient(loss, p0)

error:

ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 4 and 5")
Stacktrace:
  [1] _bcs1
    @ .\broadcast.jl:516 [inlined]
  [2] _bcs
    @ .\broadcast.jl:510 [inlined]
  [3] broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64}})
    @ Base.Broadcast .\broadcast.jl:504
  [4] combine_axes
    @ .\broadcast.jl:499 [inlined]
  [5] _axes
    @ .\broadcast.jl:224 [inlined]
  [6] axes
    @ .\broadcast.jl:222 [inlined]
  [7] copy(bc::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(ReverseDiff._add_to_deriv!), Tuple{Tuple{ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ODEFunction{false, typeof(f_aug), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLBase.StandardODEProblem}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, Matrix{Float64}, Vector{Float64}, ChainRulesCore.NoTangent}}})
    @ Base.Broadcast .\broadcast.jl:1072
  [8] materialize
    @ .\broadcast.jl:860 [inlined]
  [9] special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ODEFunction{false, typeof(f_aug), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLBase.StandardODEProblem}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, Tuple{SciMLSensitivity.var"#adjoint_sensitivity_backpass#253"{Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, Nothing, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, Vector{Float32}, SciMLBase.ReverseDiffOriginator, Tuple{}, Colon, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLSensitivity.var"##solve_up#438#364"{SciMLSensitivity.var"##solve_up#438#363#365"}, NamedTuple{(), Tuple{}}}})
    @ SciMLSensitivity C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\macros.jl:218
 [10] reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ODEFunction{false, typeof(f_aug), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLBase.StandardODEProblem}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, Tuple{SciMLSensitivity.var"#adjoint_sensitivity_backpass#253"{Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, Nothing, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, Vector{Float32}, SciMLBase.ReverseDiffOriginator, Tuple{}, Colon, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLSensitivity.var"##solve_up#438#364"{SciMLSensitivity.var"##solve_up#438#363#365"}, NamedTuple{(), Tuple{}}}})
    @ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\tape.jl:93
 [11] reverse_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
    @ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\tape.jl:87
 [12] reverse_pass!
    @ C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\tape.jl:36 [inlined]
 [13] seeded_reverse_pass!(result::Vector{Float32}, output::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, input::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, tape::ReverseDiff.GradientTape{typeof(loss), ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
    @ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\utils.jl:31
 [14] seeded_reverse_pass!(result::Vector{Float32}, t::ReverseDiff.GradientTape{typeof(loss), ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
    @ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\tape.jl:47
 [15] gradient(f::Function, input::Vector{Float32}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}})
    @ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\gradients.jl:24
 [16] gradient(f::Function, input::Vector{Float32})
    @ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\gradients.jl:22
 [17] top-level scope
    @ C:\Users\Hossein Pourbozorg\Code Projects\Mine\bug-report\br-4\br-4.jl:32
 [18] include(fname::String)
    @ Base.MainInclude .\client.jl:451
 [19] top-level scope
    @ REPL[22]:1
 [20] top-level scope
    @ C:\Users\Hossein Pourbozorg\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
in expression starting at C:\Users\Hossein Pourbozorg\Code Projects\Mine\bug-report\br-4\br-4.jl:32
      Status `C:\Users\Hossein Pourbozorg\Code Projects\Mine\bug-report\br-4\Project.toml`
  [587475ba] Flux v0.13.4
  [1dea7af3] OrdinaryDiffEq v6.19.1
  [37e2e3b7] ReverseDiff v1.14.1
  [1ed8b502] SciMLSensitivity v7.2.0

prbzrg avatar Jul 26 '22 22:07 prbzrg

I tried other AD backends — Zygote, ForwardDiff, Tracker, FiniteDifferences, and FiniteDiff — and all of them work without error.

prbzrg avatar Jul 27 '22 19:07 prbzrg