PaLM-jax

Implementation of the specific Transformer architecture from PaLM - Scaling Language Modeling with Pathways - in Jax (Equinox framework)

3 PaLM-jax issues

Hi lucidrains, thanks for your awesome work. This is not an issue about the code; I just want to ask: did you intentionally build the model this way that...

PaLM train.py fails to run with the latest Equinox. ![image](https://github.com/lucidrains/PaLM-jax/assets/14943401/dcb26fd7-d1cc-42ee-89a0-871c023e1280)

In Linux-based or M1-based environments, we need VMAP for PaLM-jax.
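The last issue mentions vmap. Equinox layers typically operate on a single unbatched input, and `jax.vmap` is the usual way to add a batch dimension. The following is a minimal sketch that assumes the forward pass handles one sequence at a time. `forward_single` and the `params` dict are hypothetical stand-ins, not PaLM-jax's actual API.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-sequence forward pass (NOT the PaLM-jax API): a toy
# embedding lookup followed by a tied output projection.
# Input: tokens of shape (seq_len,). Output: logits of shape (seq_len, vocab).
def forward_single(params, tokens):
    emb = params["embed"][tokens]       # (seq_len, dim)
    return emb @ params["embed"].T      # (seq_len, vocab)

vocab, dim, batch, seq_len = 16, 8, 4, 10
key = jax.random.PRNGKey(0)
params = {"embed": jax.random.normal(key, (vocab, dim))}
tokens = jnp.zeros((batch, seq_len), dtype=jnp.int32)

# Map over the leading (batch) axis of `tokens`; `params` is shared
# across the batch, hence in_axes=None for it.
forward_batched = jax.vmap(forward_single, in_axes=(None, 0))
logits = forward_batched(params, tokens)
print(logits.shape)  # (4, 10, 16)
```

Because `vmap` vectorizes the function instead of looping in Python, the per-sequence code stays unchanged and the batched version can still be wrapped in `jax.jit`.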