SciMLSensitivity.jl
ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 4 and 5")
Moving alg and sensealg from the solve call into the ODEProblem constructor causes an error when computing the gradient with ReverseDiff.
Code:
using Flux
using ReverseDiff
using OrdinaryDiffEq
using SciMLSensitivity

solvealg_test = Tsit5()
sensealg_test = InterpolatingAdjoint()
tspan = (0.0, 1.0)

nn = Chain(
    Dense(4, 4, tanh),
)
p0, re = Flux.destructure(nn)
u0 = rand(4, 8)

f_aug(u, p, t) = re(p)(u)

# Failing variant: alg and sensealg are passed to the ODEProblem constructor
function loss(p)
    prob = ODEProblem(f_aug, u0, tspan, p; alg=solvealg_test, sensealg=sensealg_test)
    sol = solve(prob)
    sum(sol[:, :, end])
end

#= works fine: alg and sensealg passed to solve instead
function loss(p)
    prob = ODEProblem(f_aug, u0, tspan, p)
    sol = solve(prob, solvealg_test; sensealg=sensealg_test)
    sum(sol[:, :, end])
end
=#

res1 = loss(p0)                       # forward pass succeeds
res2 = ReverseDiff.gradient(loss, p0) # reverse pass throws the DimensionMismatch
Error:
ERROR: LoadError: DimensionMismatch("arrays could not be broadcast to a common size; got a dimension with lengths 4 and 5")
Stacktrace:
[1] _bcs1
@ .\broadcast.jl:516 [inlined]
[2] _bcs
@ .\broadcast.jl:510 [inlined]
[3] broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64}})
@ Base.Broadcast .\broadcast.jl:504
[4] combine_axes
@ .\broadcast.jl:499 [inlined]
[5] _axes
@ .\broadcast.jl:224 [inlined]
[6] axes
@ .\broadcast.jl:222 [inlined]
[7] copy(bc::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}, Nothing, typeof(ReverseDiff._add_to_deriv!), Tuple{Tuple{ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ODEFunction{false, typeof(f_aug), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLBase.StandardODEProblem}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, Matrix{Float64}, Vector{Float64}, ChainRulesCore.NoTangent}}})
@ Base.Broadcast .\broadcast.jl:1072
[8] materialize
@ .\broadcast.jl:860 [inlined]
[9] special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ODEFunction{false, typeof(f_aug), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLBase.StandardODEProblem}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, Tuple{SciMLSensitivity.var"#adjoint_sensitivity_backpass#253"{Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, Nothing, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, Vector{Float32}, SciMLBase.ReverseDiffOriginator, Tuple{}, Colon, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLSensitivity.var"##solve_up#438#364"{SciMLSensitivity.var"##solve_up#438#363#365"}, NamedTuple{(), Tuple{}}}})
@ SciMLSensitivity C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\macros.jl:218
[10] reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(DiffEqBase.solve_up), Tuple{ODEProblem{Matrix{Float64}, Tuple{Float64, Float64}, false, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ODEFunction{false, typeof(f_aug), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLBase.StandardODEProblem}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}}, ReverseDiff.TrackedArray{Float64, Float64, 3, Array{Float64, 3}, Array{Float64, 3}}, Tuple{SciMLSensitivity.var"#adjoint_sensitivity_backpass#253"{Base.Pairs{Symbol, SciMLBase.AbstractSciMLAlgorithm, Tuple{Symbol, Symbol}, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, Nothing, InterpolatingAdjoint{0, true, Val{:central}, Nothing}, Matrix{Float64}, Vector{Float32}, SciMLBase.ReverseDiffOriginator, Tuple{}, Colon, NamedTuple{(:alg, :sensealg), Tuple{Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static.False}, InterpolatingAdjoint{0, true, Val{:central}, Nothing}}}}, SciMLSensitivity.var"##solve_up#438#364"{SciMLSensitivity.var"##solve_up#438#363#365"}, NamedTuple{(), Tuple{}}}})
@ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\tape.jl:93
[11] reverse_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
@ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\tape.jl:87
[12] reverse_pass!
@ C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\tape.jl:36 [inlined]
[13] seeded_reverse_pass!(result::Vector{Float32}, output::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, input::ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, tape::ReverseDiff.GradientTape{typeof(loss), ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
@ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\utils.jl:31
[14] seeded_reverse_pass!(result::Vector{Float32}, t::ReverseDiff.GradientTape{typeof(loss), ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
@ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\tape.jl:47
[15] gradient(f::Function, input::Vector{Float32}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float32, Float32, 1, Vector{Float32}, Vector{Float32}}})
@ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\gradients.jl:24
[16] gradient(f::Function, input::Vector{Float32})
@ ReverseDiff C:\Users\Hossein Pourbozorg\.julia\packages\ReverseDiff\5MMPp\src\api\gradients.jl:22
[17] top-level scope
@ C:\Users\Hossein Pourbozorg\Code Projects\Mine\bug-report\br-4\br-4.jl:32
[18] include(fname::String)
@ Base.MainInclude .\client.jl:451
[19] top-level scope
@ REPL[22]:1
[20] top-level scope
@ C:\Users\Hossein Pourbozorg\.julia\packages\CUDA\DfvRa\src\initialization.jl:52
in expression starting at C:\Users\Hossein Pourbozorg\Code Projects\Mine\bug-report\br-4\br-4.jl:32
Status `C:\Users\Hossein Pourbozorg\Code Projects\Mine\bug-report\br-4\Project.toml`
[587475ba] Flux v0.13.4
[1dea7af3] OrdinaryDiffEq v6.19.1
[37e2e3b7] ReverseDiff v1.14.1
[1ed8b502] SciMLSensitivity v7.2.0
I tried other AD backends: Zygote, ForwardDiff, Tracker, FiniteDifferences, and FiniteDiff; all of them work without error (a sketch of those calls is below).
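For reference, a minimal sketch of how the other backends were invoked (my assumption of the calls, using the same loss and p0 as above; the exact invocations and the g_* names are illustrative, not part of the original report):

using Zygote, ForwardDiff, Tracker, FiniteDiff, FiniteDifferences

# Each of these returns a gradient of loss at p0 without the DimensionMismatch
g_zygote  = Zygote.gradient(loss, p0)[1]
g_forward = ForwardDiff.gradient(loss, p0)
g_tracker = Tracker.gradient(loss, p0)[1]
g_fdiff   = FiniteDiff.finite_difference_gradient(loss, p0)
g_fdm     = FiniteDifferences.grad(central_fdm(5, 1), loss, p0)[1]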