nanoGPT
Tie LM Head Weight to Token Embedding to match official GPT2 Code
This PR ties the GPT2 lm_head weight to the token embedding weights, as is done in the official GPT2 TF implementation here.
Thank you for the PR, yes this is weight tying, a common technique https://paperswithcode.com/method/weight-tying . It reduces the number of parameters, which is probably also very helpful for distributed training overhead as there is much less to synchronize, and inside AdamW there are fewer buffers to keep track of, etc. Maybe I have a minor concern around weight regularization - the embeddings are typically not regularized but the weight matrices are. We'd want to do the correct thing when we init the optimizer?
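Concretely, I'm thinking of something along these lines when the optimizer is created (just a sketch, not the repo's configure_optimizers; the name-based rule below is only illustrative):

import torch

def make_param_groups(model, weight_decay=0.1):
    # sketch: decay the matmul weights, but not biases, norms, or embeddings.
    # names like 'wte'/'wpe' assume nanoGPT-style modules; adjust for other models.
    decay, no_decay, seen = [], [], set()
    for name, p in model.named_parameters():
        # named_parameters() already folds tied tensors into a single entry by
        # default; the id() check is just an extra guard so a shared tensor can
        # never be handed to AdamW twice.
        if not p.requires_grad or id(p) in seen:
            continue
        seen.add(id(p))
        is_embedding = name.endswith('wte.weight') or name.endswith('wpe.weight')
        if p.dim() >= 2 and not is_embedding:
            decay.append(p)
        else:
            no_decay.append(p)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]

# optimizer = torch.optim.AdamW(make_param_groups(model), lr=6e-4, betas=(0.9, 0.95))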
Hey there, making this change breaks my regression tests when comparing the models to the huggingface repo for: "gpt2", "gpt2-medium", and "gpt2-large".
@vgoklani what is the test exactly? forward pass?
It also breaks the current nanoGPT code with an error in configure_optimizers.
Yes, I'm comparing the outputs from the forward pass for both huggingface and nanoGPT.
@karpathy With respect to weight decay, I was under the impression that the token embeddings were regularized too, and only bias terms were excluded.
Unfortunately, the only source I have for that comes from the original GPT paper (Section 4.1). The GPT2 paper says they "...largely follows the details of the OpenAI GPT model", which I took to mean they kept the same weight decay pattern.
@vgoklani Noted, I'll work on fixing this.
configure_optimizers error has been fixed.
I'm seeing matching logits outputs between NanoGPT/HF with the following snippet:
import torch
from model import GPT
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# HF
model_hf = AutoModelForCausalLM.from_pretrained("gpt2")
model_hf.eval()
# NanoGPT
model = GPT.from_pretrained('gpt2', override_args=dict(dropout=0.1))
model.eval()
# Random test sentence
test_sentence = "Our model, called GPT-2 (a successor to GPT), was trained simply to predict the next word in 40GB of Internet text."
tokens = tokenizer.encode(test_sentence, return_tensors='pt')
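# nanoGPT's forward returns a (logits, loss) tuple, so [0] grabs the logits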
out_nano = model(tokens)[0]
out_hf = model_hf(tokens).logits
assert torch.allclose(out_nano, out_hf), "Output logits do not match :("
My apologies, we need to set the tie_word_embeddings parameter to True to match your changes. I previously set it to False to match nanoGPT.
from transformers import GPT2LMHeadModel

model_huggingface = GPT2LMHeadModel.from_pretrained(
    pretrained_model_name_or_path=model_type,
    tie_word_embeddings=True,
)
So I tried to add support for tied weights here https://github.com/karpathy/nanoGPT/commit/7c8288552b3673574e0649e031963b8e7e8d4981 . TL;DR it's just one line
self.lm_head.weight = self.transformer.wte.weight # https://paperswithcode.com/method/weight-tying
and then some trickery with configure_optimizers. I triple checked everything and I think it could work, but torch.compile doesn't like it and starts spamming the following:
compiling the model... (takes a ~minute)
/home/ubuntu/miniconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/utils/stateless.py:44: UserWarning: functional_call was passed multiple values for tied weights. This behavior is deprecated and will be an error in future versions
warnings.warn("functional_call was passed multiple values for tied weights. "
/home/ubuntu/miniconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/utils/stateless.py:44: UserWarning: functional_call was passed multiple values for tied weights. This behavior is deprecated and will be an error in future versions
warnings.warn("functional_call was passed multiple values for tied weights. "
...
which looks scary and is making me uncertain. So for now I'm keeping this commit on the tie_weights branch. @soumith I wasn't able to find many docs on this. Does torch.compile not play nice with tied weights perhaps?
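Aside, for anyone following along: the configure_optimizers trickery exists because once the weights are tied, PyTorch folds the two names into a single parameter, so any name-based decay/no_decay bookkeeping has to account for that. A tiny toy illustration (not nanoGPT code):

import torch
import torch.nn as nn

class TinyTiedLM(nn.Module):
    # minimal stand-in for an embedding + LM head with tied weights
    def __init__(self, vocab_size=50257, n_embd=768):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # weight tying

m = TinyTiedLM()
print([name for name, _ in m.named_parameters()])  # ['wte.weight'] -- the lm_head alias is folded in
print(sum(p.numel() for p in m.parameters()))       # the shared tensor is only counted once
assert m.lm_head.weight is m.wte.weight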
ok i merged to master for now because things don't seem broken. would like to investigate the newly produced warning separately to make sure everything is ok. Closing this PR as my implementation slightly departs from the proposal.