mlx

Implement RNN, GRU, LSTM

Open · jdeschena opened this pull request

Proposed changes

Implement recurrent cells and layers (Elman RNN, GRU, LSTM) in Python. Ultimately, it would probably be more efficient to implement them in Metal for parallelization (especially for multi-layer and bidirectional variants), but in the meantime I think it is worth having a simple implementation, if only for benchmarking.
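To make the recurrences concrete, here is a minimal sketch of an LSTM layer (the most involved of the three) written as a plain time-step loop. It uses NumPy rather than MLX so it runs anywhere. The function names `lstm_step` and `lstm`, the `(batch, seq_len, input_dims)` layout, and the (input, forget, cell, output) gate order are assumptions for illustration, not this PR's code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_step(x_t, h, c, W_ih, W_hh, b):
    """One LSTM time step (hypothetical helper, NumPy sketch).

    x_t: (batch, input_dims); h, c: (batch, hidden)
    W_ih: (4*hidden, input_dims); W_hh: (4*hidden, hidden); b: (4*hidden,)
    Assumed gate order along the last axis: input, forget, cell candidate, output.
    """
    z = x_t @ W_ih.T + h @ W_hh.T + b
    i, f, g, o = np.split(z, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    g = np.tanh(g)
    c = f * c + i * g          # new cell state
    h = o * np.tanh(c)         # new hidden state
    return h, c


def lstm(x, W_ih, W_hh, b):
    """Run the LSTM over x of shape (batch, seq_len, input_dims), starting from zero states."""
    batch, seq_len, _ = x.shape
    hidden = W_hh.shape[1]
    h = np.zeros((batch, hidden))
    c = np.zeros((batch, hidden))
    outputs = []
    for t in range(seq_len):
        h, c = lstm_step(x[:, t], h, c, W_ih, W_hh, b)
        outputs.append(h)
    # Hidden states for every step, plus the final cell state.
    return np.stack(outputs, axis=1), c


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 3))
W_ih = rng.normal(size=(16, 3))
W_hh = rng.normal(size=(16, 4))
b = np.zeros(16)
out, c_last = lstm(x, W_ih, W_hh, b)
print(out.shape)  # (2, 5, 4)
```

The Elman RNN and GRU follow the same loop structure and differ only in the step function.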

I have tested the implementations on small character-level language models.

Checklist

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

jdeschena, Dec 23 '23

Actually, I will also add the GRU and LSTM; please wait before reviewing.

jdeschena, Dec 23 '23

I think the code is in a decent state; let me know what you think once you have had time to review it.

jdeschena, Dec 26 '23

Thank you for taking the time to review the PR. The next two weeks are very busy, so I most likely won't have time to update the code, but I will finish it in the first week or two of February, if that's okay.

jdeschena, Jan 13 '24

Thanks for the contribution! I look forward to seeing the updates.

awni, Jan 13 '24

@jdeschena do you intend to come back to this PR? It would be nice to have basic RNN support, but if so it would be good to move this one forward or let someone else work on it.

awni, Feb 24 '24

Apologies for the delay; I will get it done following the previous discussions.

jdeschena, Feb 25 '24

I think the PR is ready for a new round of review. I have removed the cell modules and now compute the input projection for all time steps in one batched operation instead of sequentially. I have also trained character-level language models with it, and they generated coherent text.
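The batched input projection works because the input-to-hidden term does not depend on the hidden state: it can be computed for every time step in one matrix multiply before the loop, leaving only the hidden-to-hidden term inside the sequential recurrence. Below is a minimal NumPy sketch for an Elman RNN that checks both orderings give the same result. The function names, the single combined bias, and the `(batch, seq_len, input_dims)` layout are assumptions for illustration, not this PR's code.

```python
import numpy as np


def elman_rnn_naive(x, W_ih, W_hh, b):
    """Hypothetical baseline: project the input inside the loop, one step at a time.

    x: (batch, seq_len, input_dims); W_ih: (hidden, input_dims);
    W_hh: (hidden, hidden); b: (hidden,)
    """
    batch, seq_len, _ = x.shape
    h = np.zeros((batch, W_hh.shape[0]))
    outputs = []
    for t in range(seq_len):
        h = np.tanh(x[:, t] @ W_ih.T + h @ W_hh.T + b)
        outputs.append(h)
    return np.stack(outputs, axis=1)


def elman_rnn_batched(x, W_ih, W_hh, b):
    """Hypothetical batched variant: one matmul projects every time step up front."""
    x_proj = x @ W_ih.T + b  # (batch, seq_len, hidden)
    h = np.zeros((x.shape[0], W_hh.shape[0]))
    outputs = []
    for t in range(x.shape[1]):
        # Only the recurrent term still has to be computed step by step.
        h = np.tanh(x_proj[:, t] + h @ W_hh.T)
        outputs.append(h)
    return np.stack(outputs, axis=1)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 3))
W_ih = rng.normal(size=(4, 3))
W_hh = rng.normal(size=(4, 4))
b = rng.normal(size=(4,))
assert np.allclose(elman_rnn_naive(x, W_ih, W_hh, b), elman_rnn_batched(x, W_ih, W_hh, b))
```

The same split applies to the GRU and LSTM, where the input projection for all gates can be batched in the same way.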

jdeschena, Feb 26 '24

Looks like the format check failed. Would you mind checking it?

awni, Feb 26 '24

My bad, yes, it should be fixed now.

jdeschena, Feb 27 '24