Implement RNN, GRU, LSTM
Proposed changes
Implement recurrent cells and layers (Elman RNN, GRU, LSTM) in Python. Ultimately, it would probably be more efficient to implement them in Metal for parallelization (especially for multi-layer and bi-directional variants), but in the meantime I think it is worth having a simple Python implementation, be it only for benchmarking.
I have tested the implementations on small character-level language models.
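To give a sense of the intended structure, below is a minimal sketch of what an Elman RNN layer could look like in MLX Python. The class name, parameter names, and initialization are illustrative assumptions, not the actual code in this PR.

```python
import mlx.core as mx
import mlx.nn as nn


class ElmanRNN(nn.Module):
    """Illustrative Elman RNN layer sketch (not the PR's final API)."""

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
        super().__init__()
        scale = 1.0 / (hidden_size ** 0.5)
        # Input-to-hidden and hidden-to-hidden weights.
        self.Wxh = mx.random.uniform(-scale, scale, (input_size, hidden_size))
        self.Whh = mx.random.uniform(-scale, scale, (hidden_size, hidden_size))
        self.b = mx.zeros((hidden_size,)) if bias else None

    def __call__(self, x, hidden=None):
        # x: (batch, seq_len, input_size)
        B, T, _ = x.shape
        h = hidden if hidden is not None else mx.zeros((B, self.Whh.shape[0]))
        outputs = []
        for t in range(T):
            h = x[:, t] @ self.Wxh + h @ self.Whh
            if self.b is not None:
                h = h + self.b
            h = mx.tanh(h)
            outputs.append(h)
        # Stack per-step hidden states into (batch, seq_len, hidden_size).
        return mx.stack(outputs, axis=1)
```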
Checklist
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
Actually, I still need to add the GRU and LSTM; please wait before reviewing.
I think the code is in a decent state; let me know what you think once you have had time to review.
Thank you for taking the time to review the PR. The next two weeks are very packed, so I most likely won't have time to update the code, but I will finish it in the first week or two of February, if that's okay.
Thanks for the contribution! I look forward to seeing the updates
@jdeschena do you intend to come back to this PR? I think it could be nice to have basic RNN support, but if so it would be good to move this one forward or let someone else work on it.
Apologies for the delay; I will get it done following the previous discussions.
I think the PR is ready for a new round of review. I have removed the cell modules and batched the input projection instead of computing it sequentially. I have also fit character-level language models, and they generated coherent text.
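As a rough sketch of what batching the input projection could look like (function and variable names here are hypothetical, not the PR's code): the input-to-hidden projection is computed for all time steps in a single matmul up front, so only the hidden-to-hidden recurrence stays in the sequential loop.

```python
import mlx.core as mx


def rnn_forward_batched(x, Wxh, Whh, b):
    # x: (batch, seq_len, input_size)
    # Project the inputs for all time steps at once, outside the loop.
    x_proj = x @ Wxh + b  # (batch, seq_len, hidden_size)
    h = mx.zeros((x.shape[0], Whh.shape[0]))
    outputs = []
    for t in range(x_proj.shape[1]):
        # Only the recurrent (hidden-to-hidden) part remains sequential.
        h = mx.tanh(x_proj[:, t] + h @ Whh)
        outputs.append(h)
    return mx.stack(outputs, axis=1)
```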
Looks like a formatting check failed? Would you mind checking it?
My bad, yes, it should be fixed now.