
Tie LM Head Weight to Token Embedding to match official GPT2 Code

fattorib opened this pull request 2 years ago • 8 comments

This PR updates the GPT2 lm_head weight by tying it to the token embedding weights, matching what the official GPT2 TensorFlow implementation does.
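For reference, a minimal sketch of what the tie looks like in PyTorch (the module names mirror nanoGPT's wte/lm_head for illustration; this is not the PR's exact diff):

import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=50257, n_embd=768):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)               # token embedding, shape (V, C)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)  # weight also has shape (V, C)
        # tie: both modules now reference one shared Parameter
        self.lm_head.weight = self.wte.weight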

fattorib avatar Jan 05 '23 19:01 fattorib

Thank you for the PR. Yes, this is weight tying, a common technique: https://paperswithcode.com/method/weight-tying . It reduces the number of parameters, which probably also helps with distributed training overhead, since there is much less to synchronize, and AdamW has fewer buffers to keep track of, etc. I have a minor concern around weight regularization: the embeddings are typically not regularized but the weight matrices are, so with a single shared tensor we'd want to do the correct thing when we init the optimizer.
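(For illustration, a minimal sketch of one way to split decay/no-decay parameter groups when a tensor is tied; the 'wte' name filter and the lr default are assumptions borrowed from nanoGPT's naming and config, not the repo's actual configure_optimizers:)

import torch

def make_param_groups(model, weight_decay=0.1, lr=6e-4):
    decay, no_decay = [], []
    for name, p in model.named_parameters():  # tied tensors appear only once here
        if p.dim() >= 2 and 'wte' not in name:
            decay.append(p)     # ordinary weight matrices: regularize
        else:
            no_decay.append(p)  # biases, norms, and the shared embedding: don't
    return torch.optim.AdamW(
        [{'params': decay, 'weight_decay': weight_decay},
         {'params': no_decay, 'weight_decay': 0.0}],
        lr=lr,
    )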

karpathy avatar Jan 06 '23 03:01 karpathy

Hey there, making this change breaks my regression tests when comparing the models against the Hugging Face checkpoints for "gpt2", "gpt2-medium", and "gpt2-large".

vgoklani avatar Jan 06 '23 03:01 vgoklani

@vgoklani what is the test exactly? A forward pass? It also breaks the current nanoGPT code with an error in configure_optimizers.

karpathy avatar Jan 06 '23 03:01 karpathy

Yes, I'm comparing the outputs from the forward pass for both Hugging Face and nanoGPT.

vgoklani avatar Jan 06 '23 03:01 vgoklani

@karpathy With respect to weight decay, I was under the impression that the token embeddings were regularized too, and that only the bias terms were excluded.

Unfortunately, the only source I have for that is the original GPT paper (Section 4.1). The GPT2 paper says it "...largely follows the details of the OpenAI GPT model", which I took to mean they kept the same weight decay scheme.

@vgoklani Noted, I'll work on fixing this.

fattorib avatar Jan 06 '23 03:01 fattorib

configure_optimizers error has been fixed.

I'm seeing matching logits between nanoGPT and HF with the following snippet:

import torch
from model import GPT
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# HF reference model
model_hf = AutoModelForCausalLM.from_pretrained("gpt2")
model_hf.eval()

# nanoGPT (dropout is inactive in eval mode regardless of the override)
model = GPT.from_pretrained('gpt2', override_args=dict(dropout=0.1))
model.eval()

# Random test sentence
test_sentence = "Our model, called GPT-2 (a successor to GPT), was trained simply to predict the next word in 40GB of Internet text."
tokens = tokenizer.encode(test_sentence, return_tensors='pt')

# nanoGPT's forward returns (logits, loss); HF returns an output object
out_nano = model(tokens)[0]
out_hf = model_hf(tokens).logits

assert torch.allclose(out_nano, out_hf), "Output logits do not match :("
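(One caveat: torch.allclose uses strict defaults, rtol=1e-5 and atol=1e-8, so depending on hardware and dtype the comparison may need a looser tolerance, e.g. torch.allclose(out_nano, out_hf, atol=1e-5).)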

fattorib avatar Jan 06 '23 04:01 fattorib

My apologies, we need to set the tie_word_embeddings parameter to True to match your changes; I had previously set it to False to match nanoGPT.

from transformers import GPT2LMHeadModel

model_huggingface = GPT2LMHeadModel.from_pretrained(
    pretrained_model_name_or_path=model_type,  # e.g. "gpt2"
    tie_word_embeddings=True,
)
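(Extra keyword arguments to from_pretrained that match config fields are forwarded to the config, so the flag can be sanity-checked after loading:)

assert model_huggingface.config.tie_word_embeddings is True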

vgoklani avatar Jan 06 '23 04:01 vgoklani

So I tried to add support for tied weights here: https://github.com/karpathy/nanoGPT/commit/7c8288552b3673574e0649e031963b8e7e8d4981 . TL;DR it's just one line:

self.lm_head.weight = self.transformer.wte.weight # https://paperswithcode.com/method/weight-tying

and then some trickery in configure_optimizers. I triple-checked everything and I think it should work, but torch.compile doesn't like it and starts spamming the following:

compiling the model... (takes a ~minute)
/home/ubuntu/miniconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/utils/stateless.py:44: UserWarning: functional_call was passed multiple values for tied weights. This behavior is deprecated and will be an error in future versions
  warnings.warn("functional_call was passed multiple values for tied weights. "
/home/ubuntu/miniconda3/envs/pytorch2/lib/python3.10/site-packages/torch/nn/utils/stateless.py:44: UserWarning: functional_call was passed multiple values for tied weights. This behavior is deprecated and will be an error in future versions
  warnings.warn("functional_call was passed multiple values for tied weights. "
...

which looks scary and is making me uncertain, so for now I'm keeping this commit on the tie_weights branch. @soumith I wasn't able to find many docs on this. Does torch.compile not play nice with tied weights, perhaps?
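(A quick sanity check that the tie is intact after loading, assuming the nanoGPT attribute names used above:)

# both modules should reference the same Parameter object
assert model.lm_head.weight is model.transformer.wte.weight
# model.parameters() deduplicates shared tensors, so the tied weight is counted once
n_params = sum(p.numel() for p in model.parameters())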

karpathy avatar Jan 14 '23 01:01 karpathy

ok I merged to master for now because things don't seem broken. I'd like to investigate the new warning separately to make sure everything is ok. Closing this PR as my implementation slightly departs from the proposal.

karpathy avatar Jan 14 '23 20:01 karpathy