RNNs redesign
A complete rework of our recurrent layers, making them more similar to their PyTorch counterparts. This is in line with the proposal in #1365 and should allow hooking into the cuDNN machinery (future PR). Hopefully, this ends the endless source of trouble that the recurrent layers have been.
- `Recur` is no more. Mutating its internal state was a source of problems for AD (#2185).
- `RNNCell` is now exported and takes care of the minimal recursion step, i.e. a single time step (see the usage sketch after this list):
  - has forward `cell(x, h)`
  - `x` can be of size `in` or `in x batch_size`
  - `h` can be of size `out` or `out x batch_size`
  - returns `hnew` of size `out` or `out x batch_size`
- `RNN` instead takes in a (batched) sequence and a (batched) hidden state and returns the hidden states for the whole sequence (also shown in the sketch below):
  - has forward `rnn(x, h)`
  - `x` can be of size `in x len` or `in x len x batch_size`
  - `h` can be of size `out` or `out x batch_size`
  - returns `hnew` of size `out x len` or `out x len x batch_size`
- `LSTM` and `GRU` are similarly changed.
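
A minimal usage sketch of the new interface described above. The `in => out` constructor form and the variable names are assumptions for illustration (following the convention used elsewhere in Flux, e.g. `Dense`), not something this PR spells out; only the array sizes and the `cell(x, h)` / `rnn(x, h)` forwards come from the list above.

```julia
using Flux

in_dim, out_dim, len, batch_size = 3, 5, 7, 4

# Single-step cell: processes one time step at a time.
# Constructor form `RNNCell(in_dim => out_dim)` is assumed here.
cell = RNNCell(in_dim => out_dim)

x = rand(Float32, in_dim, batch_size)     # one (batched) time step, size in x batch_size
h = zeros(Float32, out_dim, batch_size)   # hidden state, size out x batch_size
hnew = cell(x, h)                         # size out x batch_size

# Full-sequence layer: consumes a whole (batched) sequence at once.
rnn = RNN(in_dim => out_dim)

xs = rand(Float32, in_dim, len, batch_size)   # size in x len x batch_size
h0 = zeros(Float32, out_dim, batch_size)      # initial hidden state
hs = rnn(xs, h0)                              # size out x len x batch_size

# Per the checklist item below, omitting the hidden state
# should default it to zero:
hs_default = rnn(xs)
```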
Close #2185, close #2341, close #2258, close #1547, close #807, close #1329
Related to #1678
PR Checklist
- [x] cpu tests
- [x] gpu tests
- [x] if hidden state not given as input, assumed to be zero
- [x] port `LSTM` and `GRU`
- [ ] Entry in NEWS.md
- [x] Remove `reset!`
- [x] Docstrings
- [ ] add an option in constructors to have trainable initial state
- [ ] Benchmarks
- [ ] use `cuDNN` (future PR)
- [ ] implement the `num_layers` argument for stacked RNNs (future PR)
- [ ] revisit whole documentation (future PR)
- [ ] add dropout (future PR)