Flux.jl
Initial state in RNNs should not be learnable by default
Currently, the initial states of the RNN cells are initialised as param's (e.g. here and here), which makes them trainable. This causes the initial state to be modified during backprop, which can in turn affect the model when reset! is called.
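For illustration, here is a minimal reproduction of the issue (a sketch against a recent Flux version, where the cell stores its initial value in a state0 field, as discussed further below; older Tracker-based versions use different field names):

using Flux

m = RNN(2, 3)                        # Recur wrapping an RNNCell
ps = Flux.params(m)                  # includes m.cell.state0 by default
x = rand(Float32, 2)
gs = gradient(() -> sum(m(x)), ps)

# A non-nothing gradient here means an optimiser step will modify the
# initial state, i.e. the value Flux.reset!(m) resets to:
gs[m.cell.state0]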
The default behaviour should be for the initial cell state to stay constant and not be affected by backprop. Having it learned, as it is now, is still useful in some contexts, so it should stay available as an option.
Have there been any updates on this? When I'm using RNNs with Flux.reset!, there are gradients like those mentioned at the end of #808. My workaround is to keep a copy of the initial hidden state outside the Recur cell and use that to reset the hidden state. (This is with Zygote.)
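A rough sketch of that workaround (assuming the mutable state field of the Recur wrapper; exact field names vary across Flux versions):

using Flux

m = RNN(2, 3)
h0 = copy(m.state)    # snapshot the initial hidden state up front

# ... run/train the model over a sequence ...

m.state = copy(h0)    # manual reset to the saved copy, instead of
                      # relying on Flux.reset!(m)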
@jeremiedb has this been fixed?
Treating the initial state as learnable parameters is still the default behavior for RNN; nothing was changed in the latest PR.
My position on the subject, however, is that the initial state should continue to be treated as learnable parameters. It's debatable which case is more prevalent; on my end, for NLP or time series, learnable has been the desired behavior. The more objective argument I have is that the CUDNN RNN handles initial states as learnable parameters, which I think makes it an expected default behavior. Add to that that excluding the initial state from the learnable parameters is fairly trivial, and I have some difficulty seeing how changing the current behavior brings an improvement.
I could perhaps add a quick section in the docs about initial state handling, given that I skipped discussing that question explicitly.
Actually, an option to make the learnable initial state a more first-class citizen could be to use the same approach taken with the bias option of the Dense layer. We could add a learnable_init=true option to the RNN cells for that purpose: with learnable_init=false, the initial state would be set to Zeros, that is, non-learnable zeros.
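A hypothetical sketch of that option (learnable_init is not an existing Flux keyword; Flux.Zeros is the non-learnable zero container used for the no-bias case in Dense):

using Flux

# Hypothetical helper: with learnable_init=false, the initial state is a
# Flux.Zeros, which Flux's parameter collection treats as non-trainable.
function rnn_state0(out::Integer; init = Flux.glorot_uniform,
                    learnable_init::Bool = true)
  learnable_init ? init(out, 1) : Flux.Zeros(out, 1)
end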
Unless there are cases in which one needs a non-learnable, non-zero initial state, that seems like a good solution. That, plus showing in the docs some examples along the lines of
# Non-trainable state0 for all RNN cells
Flux.trainable(m::Flux.RNNCell) = (m.Wi, m.Wh, m.b)

# Alternatively, exclude state0 from params only for a specific cell
ps = Flux.params(m)
delete!(ps, m.state0)
should solve the problem.
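Putting the second variant together (a sketch; here m is the Recur-wrapped RNN rather than the bare cell, hence m.cell.state0):

using Flux

m = RNN(2, 3)
ps = Flux.params(m)
delete!(ps, m.cell.state0)

opt = Descent(0.1)
x = rand(Float32, 2)
gs = gradient(() -> sum(m(x)), ps)
Flux.Optimise.update!(opt, ps, gs)   # leaves m.cell.state0 untouched
Flux.reset!(m)                       # back to the original initial state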