BoundsError when calling `Flux.reset!` inside a loss function
Calling `Flux.reset!` inside a loss function that is differentiated by Zygote (e.g. via `Flux.train!`) throws `ERROR: BoundsError: attempt to access Tuple{} at index [0]` from Zygote/ChainRules code.
Versions: Julia 1.8.0, Zygote v0.6.45, Flux v0.13.5.
Test case:
using Flux
# Minimal reproduction: calling `Flux.reset!` inside a loss function that
# Zygote differentiates throws
# `BoundsError: attempt to access Tuple{} at index [0]`.
function test()
    # A single non-recurrent layer. `reset!` has no `Recur` method to hit here,
    # so the generic method recurses into its fields (weight, bias, σ).
    model = Dense(1 => 1)
    params = Flux.params(model)
    function loss(x, y)
        # This call is what triggers the error; removing it makes `train!` succeed.
        Flux.reset!(model)
        Flux.Losses.mse(model(x), y)
    end
    # One gradient-descent step over a single (x, y) pair.
    Flux.train!(loss, params, [([1.0],[1.0])], Descent())
end
Running `test()` produces:
ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
[1] getindex(t::Tuple, i::Int64)
@ Base .\tuple.jl:29
[2] last(a::Tuple{})
@ Base .\abstractarray.jl:479
[3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(foldl), op::Base.var"#57#58"{typeof(Flux.reset!)}, x::Tuple{}; init::Nothing)
@ ChainRules C:\Users\joel\.julia\packages\ChainRules\fgVxV\src\rulesets\Base\mapreduce.jl:448
[4] chain_rrule_kw(::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::Function, ::NamedTuple{(:init,), Tuple{Nothing}}, ::Function, ::Function, ::Vararg{Any})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\chainrules.jl:230
[5] macro expansion
@ C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0 [inlined]
[6] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:9
[7] _pullback
@ .\tuple.jl:555 [inlined]
[8] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
[9] _pullback
@ C:\Users\joel\.julia\packages\Flux\EXOFx\src\layers\recurrent.jl:180 [inlined]
[10] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
[11] _pullback
@ .\abstractarray.jl:2774 [inlined]
[12] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
[13] _pullback
@ C:\Users\joel\.julia\packages\Flux\EXOFx\src\layers\recurrent.jl:180 [inlined]
[14] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
[15] _pullback
@ C:\data\julia\journey\modules\lev2-util\ml\TryFlux.jl:23 [inlined]
[16] _pullback(::Zygote.Context{true}, ::TryFlux.var"#loss#21"{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Vector{Float64}, ::Vector{Float64})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
[17] _apply(::Function, ::Vararg{Any})
@ Core .\boot.jl:816
[18] adjoint
@ C:\Users\joel\.julia\packages\Zygote\qGFGD\src\lib\lib.jl:203 [inlined]
[19] _pullback
@ C:\Users\joel\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined]
[20] _pullback
@ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:120 [inlined]
[21] _pullback(::Zygote.Context{true}, ::Flux.Optimise.var"#37#40"{TryFlux.var"#loss#21"{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, Tuple{Vector{Float64}, Vector{Float64}}})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface2.jl:0
[22] pullback(f::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface.jl:373
[23] gradient(f::Function, args::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}})
@ Zygote C:\Users\joel\.julia\packages\Zygote\qGFGD\src\compiler\interface.jl:96
[24] macro expansion
@ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:119 [inlined]
[25] macro expansion
@ C:\Users\joel\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
[26] train!(loss::Function, ps::Zygote.Params{Zygote.Buffer{Any, Vector{Any}}}, data::Vector{Tuple{Vector{Float64}, Vector{Float64}}}, opt::Flux.Optimise.Descent; cb::Flux.Optimise.var"#38#41")
@ Flux.Optimise C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:117
[27] train!
@ C:\Users\joel\.julia\packages\Flux\EXOFx\src\optimise\train.jl:113 [inlined]
[28] test()
@ TryFlux C:\data\julia\journey\modules\lev2-util\ml\TryFlux.jl:26
[29] top-level scope
@ REPL[12]:1
Removing the call to `Flux.reset!` makes the error go away.
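The error is raised inside the ChainRules `foldl` rule when it is applied to an empty tuple. That empty tuple comes from the `foreach` in the generic `reset!` method when it reaches a leaf node (such as a weight array):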
# Recurrent layers: restore the hidden state to the (possibly trainable) initial state.
reset!(m::Recur) = (m.state = m.cell.state0)

# Any other node: recurse into its Functors children.
function reset!(m)
    children = functor(m)[1]
    # Leaves (e.g. weight arrays, activation functions) have no children:
    # `functor(leaf)[1] === ()`. Skip them explicitly instead of running
    # `foreach(reset!, ())`. Under Zygote that empty `foreach` goes through the
    # ChainRules `foldl` rule, which calls `last(())` and throws a BoundsError.
    # `reset!` itself stays differentiable, so gradients can still flow to `state0`.
    isempty(children) && return nothing
    foreach(reset!, children)
    return nothing
end
julia> Functors.functor(Dense(1 => 1).weight)[1]
()
It still fails with https://github.com/JuliaDiff/ChainRules.jl/pull/569 applied, now with this stacktrace:
julia> test()
ERROR: BoundsError: attempt to access Tuple{} at index [0]
Stacktrace:
[1] getindex(t::Tuple, i::Int64)
@ Base ./tuple.jl:29
[2] last(a::Tuple{})
@ Base ./abstractarray.jl:500
[3] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Base._InitialValue, x::Tuple{Nothing})
@ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:465
[4] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{true}}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), op::Base.var"#57#58"{typeof(Flux.reset!)}, init::Nothing, x::Tuple{})
@ ChainRules ~/.julia/packages/ChainRules/fK4AU/src/rulesets/Base/mapreduce.jl:488
[5] chain_rrule
@ ~/.julia/packages/Zygote/qGFGD/src/compiler/chainrules.jl:218 [inlined]
[6] macro expansion
@ ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0 [inlined]
[7] _pullback(::Zygote.Context{true}, ::typeof(Base.mapfoldl_impl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Nothing, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:9
[8] _pullback
@ ./reduce.jl:170 [inlined]
[9] _pullback(::Zygote.Context{true}, ::Base.var"##mapfoldl#286", ::Nothing, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[10] _pullback
@ ./reduce.jl:170 [inlined]
[11] _pullback(::Zygote.Context{true}, ::Base.var"#mapfoldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(mapfoldl), ::typeof(identity), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[12] _pullback
@ ./reduce.jl:193 [inlined]
[13] _pullback(::Zygote.Context{true}, ::Base.var"##foldl#287", ::Base.Pairs{Symbol, Nothing, Tuple{Symbol}, NamedTuple{(:init,), Tuple{Nothing}}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[14] _pullback
@ ./reduce.jl:193 [inlined]
[15] _pullback(::Zygote.Context{true}, ::Base.var"#foldl##kw", ::NamedTuple{(:init,), Tuple{Nothing}}, ::typeof(foldl), ::Base.var"#57#58"{typeof(Flux.reset!)}, ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[16] _pullback
@ ./tuple.jl:602 [inlined]
[17] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::Tuple{})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[18] _pullback
@ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
[19] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Matrix{Float32})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[20] _pullback
@ ./abstractarray.jl:3036 [inlined]
[21] _pullback(::Zygote.Context{true}, ::typeof(foreach), ::typeof(Flux.reset!), ::NamedTuple{(:weight, :bias, :σ), Tuple{Matrix{Float32}, Vector{Float32}, typeof(identity)}})
@ Zygote ~/.julia/packages/Zygote/qGFGD/src/compiler/interface2.jl:0
[22] _pullback
@ ~/.julia/packages/Flux/EXOFx/src/layers/recurrent.jl:180 [inlined]
[23] _pullback(ctx::Zygote.Context{true}, f::typeof(Flux.reset!), args::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}})
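One possible workaround is shown below. Note that it is type piracy: neither `foreach` nor `Tuple` is owned by the code defining the rule.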
julia> ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})
julia> Zygote.refresh()
julia> test()
More generally, should Flux be differentiating inside `reset!` at all?
> More generally, should Flux be differentiating inside `reset!` at all?
My understanding of https://github.com/FluxML/Flux.jl/pull/808#issuecomment-510864610 is that this is intentional: it allows the initial state to be trainable. Perhaps there is another way for us to make that work.