Enzyme-JAX

update llama

Open · ftynse opened this issue 1 year ago • 1 comment

  • fix rmsnorm to compute what it should
  • use complex numbers instead of rotation matrix
  • don't append
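As a rough illustration of the first two items, here is a minimal JAX sketch. It assumes the textbook RMSNorm definition and an interleaved-pair rotary embedding; the function names, the `eps` and `base` defaults, and the pair layout are assumptions for illustration, not taken from the code in this repository. `rope_matrix` is included only to show that the complex-number form matches an explicit 2x2 rotation.

```python
import jax.numpy as jnp


def rmsnorm(x, weight, eps=1e-5):
    # Textbook RMSNorm over the last axis: weight * x / sqrt(mean(x^2) + eps).
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return weight * x / jnp.sqrt(mean_sq + eps)


def rope_complex(x, pos, base=10000.0):
    # Treat each interleaved pair (x[2i], x[2i+1]) of a 1-D vector as one
    # complex number and rotate it by multiplying with exp(i * pos * freq_i).
    dim = x.shape[-1]
    freqs = base ** (-jnp.arange(0, dim, 2) / dim)
    z = x[0::2] + 1j * x[1::2]
    z = z * jnp.exp(1j * pos * freqs)
    # Re-interleave real and imaginary parts back into the original layout.
    return jnp.stack([z.real, z.imag], axis=-1).reshape(x.shape)


def rope_matrix(x, pos, base=10000.0):
    # Same rotation written as an explicit 2x2 rotation per pair.
    dim = x.shape[-1]
    angles = pos * base ** (-jnp.arange(0, dim, 2) / dim)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    a, b = x[0::2], x[1::2]
    return jnp.stack([a * cos - b * sin, a * sin + b * cos], axis=-1).reshape(x.shape)


x = jnp.arange(8.0)
same_rotation = bool(jnp.allclose(rope_complex(x, 3), rope_matrix(x, 3), atol=1e-4))
```

Both rotary variants should agree up to floating-point error; the complex form simply replaces the four multiplies per pair with a single complex multiply.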

ftynse · Jul 18 '24 09:07

I am adapting this from another version that is known to run on proper data. Note that the dense-layer weights coming from PyTorch are transposed: my version computes x @ weights.mT, whereas this version computes weights @ x.
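To make the layout difference concrete, here is a small JAX sketch. It assumes the weights have shape (out_features, in_features), as in PyTorch dense layers; the helper names and the shapes are hypothetical. `jnp.swapaxes(weights, -1, -2)` stands in for `weights.mT`.

```python
import jax
import jax.numpy as jnp


def dense_rows(x, weights):
    # Inputs stored as rows: x has shape (..., in_features).
    # Equivalent to x @ weights.mT for weights of shape (out_features, in_features).
    return x @ jnp.swapaxes(weights, -1, -2)


def dense_cols(weights, x):
    # Inputs stored as a vector (in_features,) or as columns (in_features, n).
    return weights @ x


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
weights = jax.random.normal(key_w, (4, 3))  # (out_features, in_features)
x = jax.random.normal(key_x, (5, 3))        # five input rows

y_rows = dense_rows(x, weights)    # shape (5, 4)
y_cols = dense_cols(weights, x.T)  # shape (4, 5), the transpose of y_rows
```

For a single input vector the two forms give the same result; for a batch, the outputs differ only by a transpose, so the two conventions can share the same weights as long as the input layout matches.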

ftynse · Jul 18 '24 09:07