
Add RNN and LSTM modules in `nn`

Open · coreylowman opened this issue 2 years ago · 6 comments

These modules should take and return tuples that include the hidden state.

coreylowman · Sep 30 '22 20:09

I'm still a bit confused about how training would happen and how the gradients would be recorded.

As far as I know, other deep learning libraries unroll the sequence into a higher-dimensional tensor that contains the output for each step, so you'd need a fold somewhere.
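
Concretely, that unrolling amounts to folding a step function over the time dimension. A minimal sketch in plain Rust, where `Vec<f32>` stands in for a real tensor and the cell math is just a placeholder:

```rust
/// One RNN step: combine the previous hidden state with the current input.
/// (Hypothetical placeholder; a real cell would do matmuls plus a nonlinearity.)
fn rnn_step(hidden: &[f32], input: &[f32]) -> Vec<f32> {
    hidden.iter().zip(input).map(|(h, x)| (h + x).tanh()).collect()
}

fn main() {
    let seq = [vec![0.1_f32, 0.2], vec![0.3, 0.4], vec![0.5, 0.6]];
    let h0 = vec![0.0_f32, 0.0];

    // Fold the step function over the sequence, collecting each step's
    // hidden state; stacking `outputs` would give the higher-dimensional
    // tensor of per-step outputs mentioned above.
    let (final_hidden, outputs) = seq.iter().fold(
        (h0, Vec::new()),
        |(h, mut outs), x| {
            let h_new = rnn_step(&h, x);
            outs.push(h_new.clone());
            (h_new, outs)
        },
    );
    println!("final hidden: {final_hidden:?}");
    println!("per-step outputs: {outputs:?}");
}
```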

Dimev · Oct 03 '22 21:10

Gradient tracking wouldn't work any differently than it does for the other modules in `nn`; only the inputs would be different. For example, PyTorch outlines a few cases in its LSTM docs (https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html):

  1. Give the input (N, L, H_in) with no hidden states -> the hidden states are initialized to zero, and the final hidden states are returned after running the LSTM.
  2. Give the input (N, L, H_in) with hidden states -> the passed-in hidden states are used, and the new ones are returned after running the LSTM.
  3. Both of the above for unbatched input (without the leading batch dimension N).

For dfdx this would just be multiple `impl Module<...>` blocks, one per input shape. Both PyTorch and dfdx would require that all items in a batch have the same sequence length.
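
To make that concrete, here is a rough sketch of what shape-dependent impls could look like, using simplified stand-in types rather than dfdx's actual tensor API (everything below is hypothetical):

```rust
// Sketch of shape-dependent `Module` impls mirroring the PyTorch cases above.
// `Tensor3` and `Hidden` are simplified stand-ins, not dfdx's real types,
// and the forward math is elided.

struct Tensor3<const N: usize, const L: usize, const H: usize>; // batched (N, L, H)
struct Hidden<const N: usize, const H: usize>; // hidden state per batch item

trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

struct Lstm<const I: usize, const H: usize> {
    // weights elided
}

// Case 1: batched input, no hidden state -> zero-initialize, return final states.
impl<const N: usize, const L: usize, const I: usize, const H: usize>
    Module<Tensor3<N, L, I>> for Lstm<I, H>
{
    type Output = (Tensor3<N, L, H>, Hidden<N, H>);
    fn forward(&self, _input: Tensor3<N, L, I>) -> Self::Output {
        (Tensor3, Hidden) // placeholder: run the unrolled LSTM here
    }
}

// Case 2: batched input plus hidden state -> use it, return the new one.
impl<const N: usize, const L: usize, const I: usize, const H: usize>
    Module<(Tensor3<N, L, I>, Hidden<N, H>)> for Lstm<I, H>
{
    type Output = (Tensor3<N, L, H>, Hidden<N, H>);
    fn forward(&self, _input: (Tensor3<N, L, I>, Hidden<N, H>)) -> Self::Output {
        (Tensor3, Hidden) // placeholder
    }
}

// Case 3 (unbatched) would add analogous impls over (L, I)-shaped tensors.

fn main() {
    let lstm: Lstm<8, 16> = Lstm {};
    // Case 1: no hidden state passed in, final states come back in the output.
    let (_outputs, _hidden): (Tensor3<4, 10, 16>, Hidden<4, 16>) =
        lstm.forward(Tensor3::<4, 10, 8>);
}
```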

coreylowman · Oct 03 '22 23:10

I am not that familiar with RNNs, but is it really necessary to manually provide the hidden state and return the new one as output?

Wouldn't it just work if the Module stored the hidden state of its previous invocation internally and then used it again as its input upon the next invocation? That seems easier and more intuitive for users.

The hidden and cell state (in the case of an LSTM) could optionally be returned in the output, but I don't think it is good to require the hidden state as input unless strictly necessary.

TimerErTim · Feb 03 '23 07:02

That's the standard way to do it; external user code usually handles storing the hidden state. Typically, during training you just pass in the batch of hidden states, and inference is when you actually need to store the hidden state between calls.
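
For illustration, here is a self-contained sketch of that pattern, where the caller owns the state at inference time and threads it through each call (all types and methods are hypothetical stand-ins, not dfdx's actual API):

```rust
// The caller, not the layer, stores the hidden state between invocations.

#[derive(Clone, Debug)]
struct Hidden(Vec<f32>);

struct Lstm;

impl Lstm {
    fn initial_state(&self) -> Hidden {
        Hidden(vec![0.0; 4]) // zero-initialized, matching PyTorch's default
    }
    // Takes (input, state), returns (output, new state); math is a placeholder.
    fn forward(&self, (x, Hidden(h)): (f32, Hidden)) -> (f32, Hidden) {
        let h_new: Vec<f32> = h.iter().map(|v| (v + x).tanh()).collect();
        (h_new.iter().sum(), Hidden(h_new))
    }
}

fn main() {
    let lstm = Lstm;
    // Training: pass a fresh batch of (zero) states each step.
    // Inference: the caller carries the state forward between calls:
    let mut state = lstm.initial_state();
    for x in [0.5_f32, -0.25, 1.0] {
        let (out, new_state) = lstm.forward((x, state));
        state = new_state; // user code stores it for the next call
        println!("out = {out}, state = {:?}", state);
    }
}
```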

coreylowman · Feb 07 '23 15:02

I may be wrong, but I think neither PyTorch nor TensorFlow requires a hidden state as layer input. You can still provide one optionally to tell the layer what its hidden state should be initialized to.

As for user-managed storage, returning the new hidden state as layer output seems reasonable and is standard practice, as both of the frameworks mentioned above do it that way.

I think dfdx should ideally follow a similar approach, as manually storing and providing the hidden state on every model invocation seems messy and tedious.

Therefore I propose requiring the initial hidden state at RNN layer creation (e.g. `rnn::new(..., initial_hidden_state, ...)`) but removing it as a layer input parameter. A minimal sketch of what I mean is below.
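
A minimal sketch of that proposal, with all names hypothetical and the tensor math reduced to a placeholder:

```rust
// Proposed API sketch: the layer owns its hidden state internally, seeded at
// construction time, and updates it on each call. `Vec<f32>` stands in for a
// real tensor type.

struct Rnn {
    hidden: Vec<f32>, // owned state, set at construction
}

impl Rnn {
    fn new(initial_hidden_state: Vec<f32>) -> Self {
        Rnn { hidden: initial_hidden_state }
    }

    // Note: requires `&mut self`, unlike a stateless forward call.
    fn forward(&mut self, input: &[f32]) -> Vec<f32> {
        let out: Vec<f32> = self
            .hidden
            .iter()
            .zip(input)
            .map(|(h, x)| (h + x).tanh())
            .collect();
        self.hidden = out.clone(); // stored for the next invocation
        out
    }
}

fn main() {
    let mut rnn = Rnn::new(vec![0.0, 0.0]);
    println!("{:?}", rnn.forward(&[0.1, 0.2]));
    println!("{:?}", rnn.forward(&[0.3, 0.4])); // reuses the updated state
}
```

One wrinkle with this design: `forward` needs `&mut self`, which differs from the shared-reference forward calls that stateless modules use, and it makes passing a fresh batch of hidden states during training less direct.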

If there is an error in my thoughts, I'd really like you to point it out; that's the only way to learn :) Anyway, thank you for investigating my proposal.

TimerErTim · Feb 09 '23 15:02

LSTM is one of the few major module types still missing.

hsn10 · Dec 26 '23 18:12