SciMLSensitivity.jl

Gradient of SteadyStateProblem with ArrayPartition as u0

Open garibarba opened this issue 3 years ago • 4 comments

I've been trying to track this down and I've narrowed it to this minimal version:

This works fine (computes the gradient) with u0 as an Array:

using Flux
using Statistics: mean
using DiffEqFlux, OrdinaryDiffEq
using DifferentialEquations
using DiffEqSensitivity
using RecursiveArrayTools

begin
    u0 = Float32[2.; 0.]
    dudt2 = Chain(
                x -> x.^3,
                Dense(2, 50, tanh),
                Dense(50, 2, tanh))
    p, re = Flux.destructure(dudt2) # use this p as the initial condition!
    dudt(u,p,t) = re(p)(u) # need to restructure for backprop!
    prob = SteadyStateProblem(dudt, u0, p)

    function predict_ssp()
        Array(solve(prob, DynamicSS(Tsit5(); abstol = 1f-4, reltol = 1f-3, tspan = Inf32)))
    end

    gs = Flux.gradient(() -> sum(predict_ssp()), params(p))
end

Changing u0 to an ArrayPartition, with some minor adjustments, crashes (trace at the end):

begin
    u0 = ArrayPartition(Float32[2.; 0.])
    dudt2 = Chain(
                x -> x.x[1],
                x -> x.^3,
                Dense(2, 10, tanh),
                Dense(10, 2, tanh),
                x -> ArrayPartition(x))
    p, re = Flux.destructure(dudt2) # use this p as the initial condition!
    dudt(u,p,t) = re(p)(u) # need to restructure for backprop!
    prob = SteadyStateProblem(dudt, u0, p)

    predict_ssp() = Array(solve(prob, DynamicSS(Tsit5(); abstol = 1f-4, reltol = 1f-3, tspan = Inf32)))

    gs = Flux.gradient(() -> sum(predict_ssp()), params(p))
end

Trace:

parent has 2 elements, which is incompatible with size (2, 2)")
Stacktrace:
  [1] _throw_dmrs(n::Int64, str::String, dims::Tuple{Int64, Int64})
    @ Base ./reshapedarray.jl:181
  [2] _reshape
    @ ./reshapedarray.jl:176 [inlined]
  [3] reshape
    @ ./reshapedarray.jl:112 [inlined]
  [4] reshape
    @ ./reshapedarray.jl:116 [inlined]
  [5] extract_jacobian!(#unused#::Type{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}}, result::ArrayPartition{Float32, Tuple{Vector{Float32}}}, ydual::ArrayPartition{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}}}}, n::Int64)
    @ ForwardDiff ~/.julia/dev/ForwardDiff/src/jacobian.jl:115
  [6] vector_mode_jacobian(f::SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, x::ArrayPartition{Float32, Tuple{Vector{Float32}}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2, ArrayPartition{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}}}}})
    @ ForwardDiff ~/.julia/dev/ForwardDiff/src/jacobian.jl:150
  [7] jacobian(f::Function, x::ArrayPartition{Float32, Tuple{Vector{Float32}}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2, ArrayPartition{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}}}}}, ::Val{true})
    @ ForwardDiff ~/.julia/dev/ForwardDiff/src/jacobian.jl:21
  [8] jacobian(f::Function, x::ArrayPartition{Float32, Tuple{Vector{Float32}}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2, ArrayPartition{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}}}}}) (repeats 2 times)
    @ ForwardDiff ~/.julia/dev/ForwardDiff/src/jacobian.jl:19
  [9] jacobian(f::Function, x::ArrayPartition{Float32, Tuple{Vector{Float32}}}, alg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve})
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/agdxc/src/derivative_wrappers.jl:84
 [10] SteadyStateAdjointProblem(sol::SciMLBase.NonlinearSolution{Float32, 1, ArrayPartition{Float32, Tuple{Vector{Float32}}}, ArrayPartition{Float32, Tuple{Vector{Float32}}}, SteadyStateProblem{ArrayPartition{Float32, Tuple{Vector{Float32}}}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float32, Float32, Float32}, Nothing, Nothing}, sensealg::SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, g::Nothing, dg::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}}; save_idxs::Nothing)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/agdxc/src/steadystate_adjoint.jl:41
 [11] #_adjoint_sensitivities#45
    @ ~/.julia/packages/DiffEqSensitivity/agdxc/src/sensitivity_interface.jl:57 [inlined]
 [12] #adjoint_sensitivities#42
    @ ~/.julia/packages/DiffEqSensitivity/agdxc/src/sensitivity_interface.jl:6 [inlined]
 [13] steadystatebackpass
    @ ~/.julia/packages/DiffEqSensitivity/agdxc/src/concrete_solve.jl:437 [inlined]
 [14] #93#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [15] #178
    @ ~/.julia/packages/Zygote/lwmfx/src/lib/lib.jl:194 [inlined]
 [16] (::Zygote.var"#1686#back#180"{Zygote.var"#178#179"{Tuple{NTuple{4, Nothing}, Tuple{Nothing}}, DiffEqBase.var"#93#back#73"{DiffEqSensitivity.var"#steadystatebackpass#187"{Nothing, DynamicSS{Tsit5, Float32, Float32, Float32}, SteadyStateAdjoint{0, true, Val{:central}, Bool, DefaultLinSolve}, Tuple{}, SciMLBase.NonlinearSolution{Float32, 1, ArrayPartition{Float32, Tuple{Vector{Float32}}}, ArrayPartition{Float32, Tuple{Vector{Float32}}}, SteadyStateProblem{ArrayPartition{Float32, Tuple{Vector{Float32}}}, false, Vector{Float32}, ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, DynamicSS{Tsit5, Float32, Float32, Float32}, Nothing, Nothing}}}}})(Δ::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
 [17] Pullback
    @ ~/.julia/packages/DiffEqBase/qntkj/src/solve.jl:70 [inlined]
 [18] (::typeof(∂(#solve#57)))(Δ::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#178#179"{Tuple{NTuple{6, Nothing}, Tuple{Nothing}}, typeof(∂(#solve#57))})(Δ::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/lib/lib.jl:194
 [20] #1686#back
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59 [inlined]
 [21] Pullback
    @ ~/.julia/packages/DiffEqBase/qntkj/src/solve.jl:68 [inlined]
 [22] (::typeof(∂(solve)))(Δ::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface2.jl:0
 [23] Pullback
    @ ~/project/mwe.jl:38 [inlined]
 [24] (::typeof(∂(predict_ss)))(Δ::FillArrays.Fill{Float32, 1, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface2.jl:0
 [25] Pullback
    @ ~/project/mwe.jl:40 [inlined]
 [26] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(#8)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface.jl:252
 [27] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/lwmfx/src/compiler/interface.jl:59
 [28] top-level scope
    @ ~/project/mwe.jl:40

garibarba commented Apr 07 '21 21:04

I've tried the same with an ArrayPartition in an ODEProblem and it works fine, so I assume this is specific to SteadyStateProblem.

garibarba commented Apr 07 '21 21:04

I've made some progress, and it looks like an issue with the similar method on ArrayPartition rather than anything in DiffEqSensitivity.

result = similar(ydual, valtype(eltype(ydual)), length(ydual), N) returns something of size (2,) when it should be (2, 2).

More details:

> typeof(ydual)
ArrayPartition{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.UDerivativeWrapper{ODEFunction{false, typeof(dudt), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Nothing, Vector{Float32}}, Float32}, Float32, 2}}}}

> size(similar(ydual, 2,2))
(2,)
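
For comparison, here is a minimal Base-only sketch of the behaviour that call relies on. It uses a plain Vector as a stand-in, so the variable names (v, jac_buffer, flat) are illustrative assumptions and nothing here touches ArrayPartition or ForwardDiff. It only shows that a buffer of the wrong size makes the subsequent reshape fail with the same kind of error as in the trace.

```julia
# Plain Vector stand-in for the state; names are illustrative only.
v = Float32[2, 0]

# Base arrays honour the dims passed to `similar`.
jac_buffer = similar(v, Float32, length(v), 2)
@assert size(jac_buffer) == (2, 2)

# A buffer that keeps the original shape, like the (2,) result reported above.
flat = similar(v, Float32)
@assert size(flat) == (2,)

# Reshaping a 2-element buffer to (2, 2) cannot work.
err = try
    reshape(flat, length(v), 2)
    nothing
catch e
    e
end
@assert err isa DimensionMismatch
```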

garibarba commented Apr 07 '21 22:04

Ok, the issue is definitely in ArrayPartition, because dims are ignored here: https://github.com/SciML/RecursiveArrayTools.jl/blob/734999e278ab5a6f9197669df155341d284b26fd/src/array_partition.jl#L31
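
As a rough sketch of what a dims-aware similar could look like, the following uses a hypothetical single-partition wrapper type (Wrapped, invented for this example). It is not the RecursiveArrayTools implementation and makes no claim about how a fix there should be written. It only illustrates one option: keep the wrapper when the requested shape matches, and fall back to a plain Array otherwise.

```julia
# Hypothetical wrapper, for illustration only; not RecursiveArrayTools code.
struct Wrapped{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int) = w.data[i]
Base.setindex!(w::Wrapped, v, i::Int) = (w.data[i] = v)

# Honour the requested dims: keep the wrapper if the shape matches,
# otherwise return a plain Array of exactly the requested size.
function Base.similar(w::Wrapped, ::Type{T}, dims::Dims) where {T}
    if dims == size(w)
        return Wrapped(Vector{T}(undef, length(w)))
    else
        return Array{T}(undef, dims)
    end
end

w = Wrapped(Float32[2, 0])
same_shape = similar(w, Float32, (2,))        # stays a Wrapped
jac_buffer = similar(w, Float32, length(w), 2) # becomes a 2x2 Array
@assert same_shape isa Wrapped
@assert size(jac_buffer) == (2, 2)
```

With a method like this, a reshape of the buffer to (length(w), 2) has the right number of elements. Whether such a fallback is acceptable for Zygote is a separate question.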

I think we can close this here unless you want to move the issue to RecursiveArrayTools.

garibarba commented Apr 07 '21 22:04

This is worth keeping open here. It would be good to fix this, but it might be a little hard given what Zygote requires.

ChrisRackauckas commented Apr 08 '21 11:04