
Jax/Flax Rewrite

Open jenkspt opened this issue 2 years ago • 3 comments

Thanks for the incredibly lucid GPT implementation!

I've started rewriting nanoGPT in Jax/Flax as a test-bed to play with the new jax.experimental.pjit API. Thought I'd put it here for anyone who's interested. https://github.com/jenkspt/gpt-jax
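For anyone who hasn't used it, here is a minimal sketch of the pjit pattern (illustrative only, not code from gpt-jax; pjit has since been folded into jax.jit, and the mesh axis name and array shapes below are made up):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as P

# Hypothetical 1-D mesh over all local devices, with one named axis "data".
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard the leading (batch) axis of the input across the "data" axis;
# out_axis_resources=None replicates the scalar result on every device.
mean_sq = pjit(
    lambda x: jnp.mean(x ** 2),
    in_axis_resources=P("data"),
    out_axis_resources=None,
)

with mesh:
    x = jnp.ones((8, 1024))  # batch axis must divide evenly across devices
    print(mean_sq(x))
```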

Also, following the Jax convention of jitting the entire training update step (rather than just the forward pass), I figured it might be reasonable to try torch.compile on the whole update step as well -- see https://github.com/jenkspt/gpt-jax/blob/c4e38cc35264c0eab9508bf0180c5f6e52753938/train.py#L60-L76
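In Jax, "jitting the whole update step" means the forward pass, backward pass, and optimizer update all compile into a single XLA program. A minimal Flax/Optax sketch of that pattern (a stand-in, not the code from the linked file; the Dense layer, learning rate, and shapes are placeholders):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

model = nn.Dense(features=16)  # tiny stand-in for a real GPT module
optimizer = optax.adamw(learning_rate=3e-4)

params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 8)))["params"]
opt_state = optimizer.init(params)

@jax.jit  # compiles forward, backward, and the optimizer update together
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        logits = model.apply({"params": p}, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params, opt_state, loss = train_step(
    params, opt_state, jnp.ones((4, 8)), jnp.zeros((4,), dtype=jnp.int32)
)
```

The torch.compile analog would be to wrap a function that runs the forward pass, loss.backward(), and optimizer.step() together, instead of compiling only the model's forward pass.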

jenkspt · Jan 03 '23

Good reading, thank you for the pointer. Like the compactness of model.py, and nice to discover tyro.

karpathy · Jan 03 '23

Hello @jenkspt, did you see any performance improvement while using JAX?

farzanehnakhaee70 · Oct 12 '23

The performance seems comparable, but I didn't work out the exact numbers w.r.t. time/cost for the 8x40GB A100 node used in nanoGPT vs. the TPU-v3-32 that I used. The Jax code reaches the loss shown in the nanoGPT README plot (2.905) in about 20 hours. A benefit of working with TPUs is that training is trivial to scale (just switch to a TPU-v?-64 or TPU-v?-128 ...). It should be noted that this repo is missing a bunch of features that would make it generally useful, so I don't recommend working with it unless you know what you're getting into.

jenkspt · Oct 16 '23