llama-models
JAX Llama 4 implementation
WIP; opening this PR now to get some early feedback.
Once it's ready, it would be interesting to see benchmarks of JAX versus PyTorch on the same hardware you used.
Hopefully the fact that you made PyTorch doesn't make a JAX contribution unwanted!
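For a fair JAX vs. PyTorch comparison, the JAX timings need to exclude XLA compilation and account for asynchronous dispatch. Below is a minimal sketch of such a harness, assuming a small stand-in workload. `rms_norm` and `benchmark` are hypothetical names used only for illustration and are not part of this PR or of llama-models. The PyTorch side would need an equivalent synchronization (e.g. `torch.cuda.synchronize()` on GPU) before reading the clock.

```python
import time

import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-5):
    # Hypothetical stand-in workload: RMSNorm over the last axis,
    # the normalization used in Llama-family models.
    variance = jnp.mean(x * x, axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight


def benchmark(fn, *args, warmup=3, iters=20):
    """Return mean wall-clock seconds per call of jit(fn), excluding compile time."""
    jitted = jax.jit(fn)
    # Warm-up calls trigger XLA compilation so it is not counted below.
    for _ in range(warmup):
        jitted(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches asynchronously; block so we time the real computation.
        jitted(*args).block_until_ready()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (8, 2048, 4096))  # (batch, seq_len, hidden)
    w = jnp.ones((4096,))
    print(f"rms_norm: {benchmark(rms_norm, x, w) * 1e3:.3f} ms/iter")
```

Timing each iteration with `block_until_ready()` trades a little dispatch overlap for measurements that reflect actual device work, which makes the numbers easier to compare against a synchronized PyTorch loop.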