PaLM-jax
Implementation of the specific Transformer architecture from "PaLM: Scaling Language Modeling with Pathways" in JAX (Equinox framework)
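The PaLM paper describes several departures from a standard Transformer block: attention and feedforward run in parallel on the same normalized input, the feedforward uses SwiGLU, attention is multi-query (one shared key/value head), and no biases are used in dense layers or layer norms. The following is a minimal sketch of one such block in plain JAX, not this repository's Equinox code. It omits rotary embeddings for brevity, and the names `init_params` and `parallel_block` are hypothetical.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, gamma, eps=1e-5):
    # LayerNorm with a learned scale but no bias (PaLM uses no biases)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * gamma


def swiglu_ff(x, w_in, w_gate, w_out):
    # SwiGLU feedforward: (x W_in) * swish(x W_gate), projected back to dim
    return ((x @ w_in) * jax.nn.silu(x @ w_gate)) @ w_out


def multi_query_attention(x, w_q, w_k, w_v, w_o, heads):
    # Multi-query attention: many query heads share ONE key/value head
    seq = x.shape[0]
    q = (x @ w_q).reshape(seq, heads, -1)         # (seq, heads, dim_head)
    k = x @ w_k                                   # (seq, dim_head), shared
    v = x @ w_v                                   # (seq, dim_head), shared
    scale = q.shape[-1] ** -0.5
    sim = jnp.einsum("ihd,jd->hij", q, k) * scale  # (heads, seq, seq)
    causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
    sim = jnp.where(causal, sim, -1e9)            # block attention to future tokens
    attn = jax.nn.softmax(sim, axis=-1)
    out = jnp.einsum("hij,jd->ihd", attn, v).reshape(seq, -1)
    return out @ w_o


def parallel_block(params, x, heads):
    # Parallel formulation: y = x + Attn(LN(x)) + FF(LN(x))
    h = layer_norm(x, params["gamma"])
    attn_out = multi_query_attention(
        h, params["w_q"], params["w_k"], params["w_v"], params["w_o"], heads
    )
    ff_out = swiglu_ff(h, params["w_in"], params["w_gate"], params["w_out"])
    return x + attn_out + ff_out


def init_params(key, dim, heads, dim_head, ff_mult=4):
    # Hypothetical initializer: bias-free weight matrices only
    keys = jax.random.split(key, 7)
    ff_dim = dim * ff_mult

    def w(k, shape):
        return jax.random.normal(k, shape) * shape[0] ** -0.5

    return {
        "gamma": jnp.ones((dim,)),
        "w_q": w(keys[0], (dim, heads * dim_head)),
        "w_k": w(keys[1], (dim, dim_head)),
        "w_v": w(keys[2], (dim, dim_head)),
        "w_o": w(keys[3], (heads * dim_head, dim)),
        "w_in": w(keys[4], (dim, ff_dim)),
        "w_gate": w(keys[5], (dim, ff_dim)),
        "w_out": w(keys[6], (ff_dim, dim)),
    }


params = init_params(jax.random.PRNGKey(0), dim=64, heads=8, dim_head=16)
x = jax.random.normal(jax.random.PRNGKey(1), (10, 64))  # one sequence of 10 tokens
y = parallel_block(params, x, heads=8)
```

Because the block works on a single unbatched sequence, changing the last token should leave every earlier output unchanged; that is a quick check that the causal mask is applied.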
Hi lucidrains, thanks for your awesome work. This is not an issue about the code; I just want to ask: did you intentionally build the model this way that...
PaLM train.py fails to run with the latest Equinox.
In Linux-based or M1-based environments, we need vmap for PaLM-jax
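Equinox models are commonly written to process a single unbatched example, with batching added afterwards through `jax.vmap`. The following is a minimal sketch of that pattern, assuming the forward pass takes one sequence of shape `(seq_len, dim)`. `per_example_forward` is a hypothetical stand-in, not PaLM-jax's API. With an Equinox module, the same idea is usually written as `jax.vmap(model)(batch)`.

```python
import jax
import jax.numpy as jnp


def per_example_forward(w, tokens):
    # Hypothetical stand-in for a model that expects ONE unbatched
    # sequence of shape (seq_len, dim)
    h = tokens @ w
    return jnp.cumsum(h, axis=0)  # running sum over positions


# in_axes=(None, 0): share the weights across the batch and map over the
# leading (batch) axis of the token array
batched_forward = jax.vmap(per_example_forward, in_axes=(None, 0))

w = jnp.eye(4)
batch = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)  # (batch, seq, dim)
out = batched_forward(w, batch)
```

`jax.vmap` composes with `jax.jit`, so `jax.jit(batched_forward)` compiles the batched version without changing the per-example code.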