Enzyme-JAX
update llama
- fix rmsnorm to compute what it should
- use complex numbers instead of rotation matrix
- don't append
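
As a rough illustration of the first two items, here is a minimal JAX sketch, not the code from this change. The function names, the `eps` default, the placement of `eps` inside the square root, and the interleaved (even, odd) pairing are all assumptions made for the example. It shows an RMS normalization and checks that rotating pairs through complex multiplication matches the explicit 2x2 rotation matrix form.

```python
import jax.numpy as jnp

def rmsnorm(x, weight, eps=1e-5):
    # Divide x by its root-mean-square over the last axis, then scale by weight.
    # eps inside the sqrt is an assumption, not taken from the change itself.
    mean_sq = jnp.mean(x * x, axis=-1, keepdims=True)
    return weight * x / jnp.sqrt(mean_sq + eps)

def rotate_pairs_complex(x, theta):
    # Treat each (even, odd) pair as one complex number and multiply by e^{i*theta}.
    z = x[..., 0::2] + 1j * x[..., 1::2]
    z = z * jnp.exp(1j * theta)
    return jnp.stack([z.real, z.imag], axis=-1).reshape(x.shape)

def rotate_pairs_matrix(x, theta):
    # Same rotation written out as the 2x2 matrix [[cos, -sin], [sin, cos]].
    c, s = jnp.cos(theta), jnp.sin(theta)
    x0, x1 = x[..., 0::2], x[..., 1::2]
    return jnp.stack([c * x0 - s * x1, s * x0 + c * x1], axis=-1).reshape(x.shape)

x = jnp.arange(8.0)
theta = jnp.linspace(0.1, 0.4, 4)  # one hypothetical angle per pair
same = jnp.allclose(rotate_pairs_complex(x, theta), rotate_pairs_matrix(x, theta), atol=1e-5)
```

The complex form avoids materializing the sine and cosine terms separately, which is presumably why it was preferred here.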
This is adapted from another version that is known to run on proper data. Note that the dense-layer weights coming from PyTorch are transposed: my version uses `x @ weights.mT`, whereas this version uses `weights @ x`.
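
To make the layout difference concrete, here is a small JAX sketch under the assumption that the weights are stored PyTorch-style as `(out_features, in_features)`. The helper names are hypothetical, and `jnp.swapaxes` stands in for `.mT`. For a single input vector the two conventions give the same result.

```python
import jax.numpy as jnp

def dense_row_convention(x, w):
    # x: (..., in), w: (out, in). Equivalent to x @ w.mT.
    return x @ jnp.swapaxes(w, -1, -2)

def dense_col_convention(x, w):
    # x: (in,), w: (out, in). The weights multiply the vector from the left.
    return w @ x

w = jnp.arange(6.0).reshape(2, 3)  # hypothetical (out=2, in=3) weight matrix
x = jnp.array([1.0, 2.0, 3.0])
agree = jnp.allclose(dense_row_convention(x, w), dense_col_convention(x, w))
```

The row convention also works for batched inputs of shape `(batch, in)`, while `weights @ x` as written expects a single vector or inputs laid out as columns.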