Diffractor.jl

Error with higher-order derivatives of Flux.Chain and DiffEqFlux.FastChain

Open · KirillZubov opened this issue 3 years ago · 2 comments

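```julia
using Flux, DiffEqFlux
using Diffractor: var"'", ∂⃖

# FastChain version: parameters are passed explicitly as the second argument
chain = FastChain(FastDense(1, 12, Flux.σ), FastDense(12, 1), (x, p) -> x[1])
initθ = DiffEqFlux.initial_params(chain)
chain_f = x -> chain([x], initθ)   # wrap as a scalar -> scalar function

chain_f(1.0)     # forward pass
chain_f'(1.0)    # first derivative via Diffractor's postfix '
chain_f''(1.0)   # second derivative: throws the error below
```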
```
Internal error: encountered unexpected error in runtime:
BoundsError(a=Array{Any, (0,)}[], i=(1,))
```

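With `Flux.Chain` instead:

```julia
using Flux
using Diffractor: var"'"

chain = Flux.Chain(Dense(1, 12, Flux.σ), Dense(12, 1), x -> x[1])
chain([1.0])
chain_f = x -> chain([x])   # wrap as a scalar -> scalar function
chain_f(1.0)
chain_f'(1.0)
chain_f''(1.0)   # second derivative: throws the error below
```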
```
ERROR: MethodError: no method matching +(::Tuple{ChainRulesCore.Tangent{ChainRules.var"#1201#1203"{Vector{Float64}, Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}, NamedTuple{(:ȳ, :projects), Tuple{Vector{Float64}, Tuple{ChainRulesCore.ZeroTangent}}}}, ChainRulesCore.NoTangent})
Closest candidates are:
  +(::P, ::ChainRulesCore.Tangent{P}) where P at ~/.julia/packages/ChainRulesCore/7ZiwT/src/tangent_arithmetic.jl:133
  +(::Any, ::ChainRulesCore.AbstractThunk) at ~/.julia/packages/ChainRulesCore/7ZiwT/src/tangent_arithmetic.jl:123
  +(::Any, ::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}) at ~/.julia/packages/InitialValues/P5PLf/src/InitialValues.jl:160
```
```
julia> versioninfo()
Julia Version 1.7.0-rc3
Commit 3348de4ea6 (2021-11-15 08:22 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Intel(R) Core(TM) i5-1038NG7 CPU @ 2.00GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, icelake-client)
Environment:
  JULIA_NUM_THREADS = 4
```
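Until the second-order AD path works, a finite-difference estimate can serve as a reference value to compare `chain_f''(1.0)` against. This is a minimal sketch in plain Julia (no packages), not part of Diffractor or Flux: `fd2` is a hypothetical helper name, and the step `h = 1e-4` is an assumed default that balances truncation and rounding error for `Float64` inputs.

```julia
# Hypothetical helper (not part of Diffractor or Flux): central second difference,
#   f''(x) ≈ (f(x + h) - 2f(x) + f(x - h)) / h^2,
# whose truncation error is O(h^2).
fd2(f, x::Real; h::Real = 1e-4) = (f(x + h) - 2 * f(x) + f(x - h)) / h^2

# Sanity check on a function with a known second derivative: sin'' = -sin.
fd2(sin, 1.0)   # close to -sin(1.0)
```

With the `Flux.Chain` version of `chain_f` above, `fd2(chain_f, 1.0)` gives an approximate value that a working `chain_f''(1.0)` should agree with to several digits.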

KirillZubov, Nov 29 '21 20:11

I get a different error for the Flux case:

```
julia> chain_f''(1.)
ERROR: Control flow support not fully implemented yet for higher-order reverse mode (TODO)
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] macro expansion
    @ ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:0 [inlined]
  [3] (::Diffractor.∂⃖recurse{2})(::typeof(getproperty), ::Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, ::Symbol)
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:414
  [4] (::∂⃖{2})(::typeof(getproperty), ::Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, ::Vararg{Any})
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:222
  [5] Dense
    @ ~/.julia/packages/Flux/BPPNj/src/layers/basic.jl:157 [inlined]
  [6] (::Diffractor.∂⃖recurse{2})(::Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, ::Vector{Float64})
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:0
  [7] (::∂⃖{2})(f::Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, args::Vector{Float64})
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:222
  [8] applychain
    @ ~/.julia/packages/Flux/BPPNj/src/layers/basic.jl:47 [inlined]
  [9] (::Diffractor.∂⃖recurse{2})(::typeof(Flux.applychain), ::Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#5#6"}, ::Vector{Float64})
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:0
 [10] (::∂⃖{2})(::typeof(Flux.applychain), ::Tuple{Dense{typeof(σ), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, var"#5#6"}, ::Vararg{Any})
    @ Diffractor ~/.julia/packages/Diffractor/HYuxt/src/stage1/generated.jl:222
 [11] Chain
    @ ~/.julia/packages/Flux/BPPNj/src/layers/basic.jl:49 [inlined]
...
```

```
(@v1.8) pkg> st Flux
Status `~/.julia/environments/v1.8/Project.toml`
  [587475ba] Flux v0.12.8
```

It might be worth showing more lines of the error you got.

mcabbott, Nov 29 '21 20:11

@mcabbott Oh, sorry. Yes, the Flux case gives `ERROR: Control flow support not fully implemented yet for higher-order reverse mode`. The previous error was probably caused by the overload of `Chain` in pinn.jl: https://github.com/JuliaDiff/Diffractor.jl/blob/be4eeb59879b7f6773746f9f3b4ef8df38ac9f99/test/pinn.jl#L16

KirillZubov, Nov 30 '21 13:11