[Question] Embeddings normalization by sqrt(hidden_size)
Hello there 👋
Thanks for the repo! I have one question: why do the token embeddings need to be scaled up (normalized) by sqrt(hidden_size)? https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L431-L432
Unfortunately, I cannot find an answer anywhere.
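For context, here is a minimal sketch of the step I am asking about (paraphrased from the linked lines; the variable names and the example `hidden_size` value are my own, not taken from the repo):

```python
import torch
import torch.nn as nn

hidden_size = 2048  # example value; the actual hidden_size depends on the Gemma variant
embedder = nn.Embedding(num_embeddings=256000, embedding_dim=hidden_size)

token_ids = torch.tensor([[1, 42, 7]])
hidden_states = embedder(token_ids)

# The step in question: the embedding outputs are multiplied by
# sqrt(hidden_size) before being passed to the transformer blocks.
normalizer = torch.tensor(hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
```

Is this scaling just a convention carried over from the original Transformer paper, or does it matter specifically for Gemma (e.g. because of weight tying with the output projection)?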