
RNNs redesign

Open · CarloLucibello opened this issue 4 months ago · 3 comments

A complete rework of our recurrent layers, making them more similar to their PyTorch counterparts. This is in line with the proposal in #1365 and should allow us to hook into the cuDNN machinery (future PR). Hopefully this puts an end to the endless stream of problems the recurrent layers have caused.

  • Recur is no more: mutating its internal state was a source of problems for AD (#2185).
  • RNNCell is now exported and takes care of the minimal recurrence step, i.e. a single time step (see the sketch after this list):
    • has forward cell(x, h)
    • x can be of size in or in x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out or out x batch_size
  • RNN instead takes a (batched) sequence and a (batched) initial hidden state and returns the hidden states for the whole sequence:
    • has forward rnn(x, h)
    • x can be of size in x len or in x len x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out x len or out x len x batch_size
  • LSTM and GRU are changed in the same way.
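
To make the shape conventions above concrete, here is a minimal sketch of calling the redesigned layers. Only the call forms cell(x, h) / rnn(x, h) and the array sizes come from the list above; the RNNCell(in => out) / RNN(in => out) constructor form and the Float32 test arrays are assumptions for illustration.

```julia
using Flux

in_size, out_size, len, batch_size = 3, 5, 7, 4

# Single step: RNNCell maps (x, h) -> hnew for one time step.
cell = RNNCell(in_size => out_size)            # constructor form assumed here
x = rand(Float32, in_size, batch_size)         # one time step, batched
h = zeros(Float32, out_size, batch_size)
hnew = cell(x, h)                              # size out × batch_size

# Whole sequence: RNN maps (x, h) -> the hidden states for every step.
rnn = RNN(in_size => out_size)
xs = rand(Float32, in_size, len, batch_size)   # in × len × batch_size
h0 = zeros(Float32, out_size, batch_size)
hs = rnn(xs, h0)                               # size out × len × batch_size
```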

Close #2185, close #2341, close #2258, close #1547, close #807, close #1329

Related to #1678

PR Checklist

  • [x] cpu tests
  • [x] gpu tests
  • [x] if the hidden state is not given as input, it is assumed to be zero (see the note after this checklist)
  • [x] port LSTM and GRU
  • [ ] Entry in NEWS.md
  • [x] Remove reset!
  • [x] Docstrings
  • [ ] add a constructor option for a trainable initial state
  • [ ] Benchmarks
  • [ ] use cuDNN (future PR)
  • [ ] implement the num_layers argument for stacked RNNs (future PR)
  • [ ] revisit whole documentation (future PR)
  • [ ] add dropout (future PR)
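
On the default-hidden-state checklist item above, a small hedged sketch of what it implies: omitting the hidden state behaves like passing zeros. The exact call form without a state is assumed for illustration.

```julia
using Flux

rnn = RNN(3 => 5)
xs  = rand(Float32, 3, 7, 4)                  # in × len × batch_size

# Per the checklist item, omitting the hidden state should default it to zeros
# (call form assumed for illustration):
hs_default = rnn(xs)
hs_zeros   = rnn(xs, zeros(Float32, 5, 4))
hs_default ≈ hs_zeros                         # expected to hold
```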

CarloLucibello · Oct 14 '24 21:10