SciMLSensitivity.jl icon indicating copy to clipboard operation
SciMLSensitivity.jl copied to clipboard

`EnzymeVJP` failing for `SplitFunction`s

Open vpuri3 opened this issue 2 years ago • 2 comments

The line linked below fails for `SplitODEProblem`s because `SplitFunction` has no field `f`.

https://github.com/SciML/DiffEqSensitivity.jl/blob/f06856ff850dcb6e566bddac133902127ac35660/src/adjoint_common.jl#L187

Error:

julia> include("examples/opt.jl")                                                              
fwd                                                                                            
1340.6743837932922                                                                             
bwd                                                                                            
ERROR: LoadError: type SplitFunction has no field f                                            
Stacktrace:                                                                                    
  [1] getproperty                                                                              
    @ ./Base.jl:38 [inlined]                                                                   
  [2] adjointdiffcache(g::DiffEqSensitivity.var"#df#239"{Matrix{Float64}, Colon}, sensealg::Int
erpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}, discrete::Bool, sol::ODESolution{Float64,
 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}
, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, ComponentVector{Float32, Vector{F
loat32}, Tuple{Axis{(layer_1 = ViewAxis(1:20, Axis(weight = ViewAxis(1:10, ShapedAxis((10, 1), 
NamedTuple())), bias = ViewAxis(11:20, ShapedAxis((10, 1), NamedTuple())))), layer_2 = ViewAxis
(21:31, Axis(weight = ViewAxis(1:10, ShapedAxis((1, 10), NamedTuple())), bias = ViewAxis(11:11,
 ShapedAxis((1, 1), NamedTuple())))))}}}, SplitFunction{false, ODEFunction{false, typeof(implic
it), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothi
ng, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, ODEFuncti
on{false, typeof(explicit), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVE
D), Nothing}, UniformScaling{Bool}, Vector{Float64}, Nothing, Nothing, Nothing, Nothing, Nothin
g, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), No
thing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SplitODEProblem{false}},
 Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.trivial_limiter!), Static
.False}, OrdinaryDiffEq.InterpolationData{SplitFunction{false, ODEFunction{false, typeof(implic
it), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothi
ng, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, ODEFuncti
on{false, typeof(explicit), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, 
Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVE
D), Nothing}, UniformScaling{Bool}, Vector{Float64}, Nothing, Nothing, Nothing, Nothing, Nothin
g, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), No
thing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, OrdinaryDiff
Eq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, dg::Nothing, f::SplitFunction{fa
lse, ODEFunction{false, typeof(implicit), UniformScaling{Bool}, Nothing, Nothing, Nothing, Noth
ing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.D
EFAULT_OBSERVED), Nothing}, ODEFunction{false, typeof(explicit), UniformScaling{Bool}, Nothing,
 Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Noth
ing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, UniformScaling{Bool}, Vector{Float64}, Nothi
ng, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, t
ypeof(SciMLBase.DEFAULT_OBSERVED), Nothing}; quad::Bool, noiseterm::Bool, needs_jac::Bool)
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/Pn9H4/src/adjoint_common.jl:187

vpuri3 avatar Jun 23 '22 20:06 vpuri3

Minimal working example (MWE):

using OrdinaryDiffEq, DiffEqSensitivity, LinearAlgebra

# First half of the split right-hand side: multiplies `u` by the diagonal
# matrix built from `p`. The time argument `t` is accepted but not used.
function f1(u, p, t)
    return Diagonal(p) * u
end
# Second half of the split right-hand side: multiplies `u` by the diagonal
# matrix built from `p`. The time argument `t` is accepted but not used.
function f2(u, p, t)
    return Diagonal(p) * u
end

# Problem setup: random 4-element initial state, time span, and the save
# times 0.0, 0.5, 1.0 materialized as a vector.
u0 = rand(4)
tsp = (0.0, 1.0)
tsv = collect(0:0.5:1.0)
# Non-split formulation left commented out by the reporter:
# prob = ODEProblem(f1, u0, tsp)
prob = SplitODEProblem(f1, f2, u0, tsp)

# Solve the global `prob` with parameters `p` using Tsit5, saving at `tsv`,
# and return the solution converted with `Array`.
function predict(p)
    # Adjoint configuration that triggers the reported failure.
    # Alternative left commented out by the reporter:
    #   sensealg = InterpolatingAdjoint(autojacvec=ZygoteVJP())
    sensealg = InterpolatingAdjoint(autojacvec=EnzymeVJP())
    sol = solve(prob, Tsit5(); saveat=tsv, p=p, sensealg=sensealg)
    return Array(sol)
end

# Sum of squared deviations of the prediction from 1.0.
# Returns the tuple `(loss value, prediction)`.
function loss(p)
    pred = predict(p)
    sqerr = sum(abs2, pred .- 1.0)
    return sqerr, pred
end

# Driver: evaluate the loss once ("fwd"), then differentiate it with respect
# to the parameters ("bwd").
p = rand(4)

println("fwd")
display(loss(p)[1])

println("bwd")
# NOTE(review): `gradient` is unqualified and none of the `using` lines shown
# in this snippet visibly provide it; presumably Zygote's `gradient` was
# already loaded in the reporter's session — confirm.
display(gradient(q -> loss(q)[1], p))
julia> include("examples/enzyme_fail.jl")                                            [155/1865]
fwd                                                                                            
14.868868398111983                                                                             
bwd                                                                                            
ERROR: LoadError: type SplitFunction has no field f                                            
Stacktrace:                                                                                    
  [1] getproperty                                                                              
    @ ./Base.jl:38 [inlined]                                                                   
  [2] adjointdiffcache(g::DiffEqSensitivity.var"#df#239"{Matrix{Float64}, Colon}, sensealg::Int
erpolatingAdjoint{0, true, Val{:central}, EnzymeVJP}, discrete::Bool, sol::ODESolution{Float64,
 2, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}
, ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, SplitFunction{fa
lse, ODEFunction{false, typeof(f1), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, N
othing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT
_OBSERVED), Nothing}, ODEFunction{false, typeof(f2), UniformScaling{Bool}, Nothing, Nothing, No
thing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(
SciMLBase.DEFAULT_OBSERVED), Nothing}, UniformScaling{Bool}, Vector{Float64}, Nothing, Nothing,
 Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLB
ase.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}},
 SplitODEProblem{false}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.
trivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SplitFunction{false, ODEFunc
tion{false, typeof(f1), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Noth
ing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), 
Nothing}, ODEFunction{false, typeof(f2), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothi
ng, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DE
FAULT_OBSERVED), Nothing}, UniformScaling{Bool}, Vector{Float64}, Nothing, Nothing, Nothing, No
thing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_
OBSERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}},
 OrdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, dg::Nothing, f::Spl
itFunction{false, ODEFunction{false, typeof(f1), UniformScaling{Bool}, Nothing, Nothing, Nothin
g, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciM
LBase.DEFAULT_OBSERVED), Nothing}, ODEFunction{false, typeof(f2), UniformScaling{Bool}, Nothing
, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Not
hing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, UniformScaling{Bool}, Vector{Float64}, Noth
ing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, 
typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}; quad::Bool, noiseterm::Bool, needs_jac::Bool)    
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/Pn9H4/src/adjoint_common.jl:187    
  [3] DiffEqSensitivity.ODEInterpolatingAdjointSensitivityFunction(g::Function, sensealg::Inter
polatingAdjoint{0, true, Val{:central}, EnzymeVJP}, discrete::Bool, sol::ODESolution{Float64, 2
, Vector{Vector{Float64}}, Nothing, Nothing, Vector{Float64}, Vector{Vector{Vector{Float64}}}, 
ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, false, Vector{Float64}, SplitFunction{fals
e, ODEFunction{false, typeof(f1), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Not
hing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_O
BSERVED), Nothing}, ODEFunction{false, typeof(f2), UniformScaling{Bool}, Nothing, Nothing, Noth
ing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(Sc
iMLBase.DEFAULT_OBSERVED), Nothing}, UniformScaling{Bool}, Vector{Float64}, Nothing, Nothing, N
othing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBas
e.DEFAULT_OBSERVED), Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, S
plitODEProblem{false}}, Tsit5{typeof(OrdinaryDiffEq.trivial_limiter!), typeof(OrdinaryDiffEq.tr
ivial_limiter!), Static.False}, OrdinaryDiffEq.InterpolationData{SplitFunction{false, ODEFuncti
on{false, typeof(f1), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothin
g, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), No
thing}, ODEFunction{false, typeof(f2), UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing
, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFA
ULT_OBSERVED), Nothing}, UniformScaling{Bool}, Vector{Float64}, Nothing, Nothing, Nothing, Noth
ing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OB
SERVED), Nothing}, Vector{Vector{Float64}}, Vector{Float64}, Vector{Vector{Vector{Float64}}}, O
rdinaryDiffEq.Tsit5ConstantCache{Float64, Float64}}, DiffEqBase.DEStats}, dg::Nothing, f::Funct
ion, checkpoints::Vector{Float64}, tols::NamedTuple{(:reltol, :abstol), Tuple{Float64, Float64}
}, tstops::Nothing; noiseterm::Bool)                                                           
    @ DiffEqSensitivity ~/.julia/packages/DiffEqSensitivity/Pn9H4/src/interpolating_adjoint.jl:
72

vpuri3 avatar Jun 23 '22 21:06 vpuri3

Thanks, yes it looks like this needs a specialization.

ChrisRackauckas avatar Jun 25 '22 09:06 ChrisRackauckas