Shape mismatch error when resizing token embeddings in OLMo modeling code
🐛 Describe the bug
I encountered a shape mismatch error when fine-tuning OLMo-1B using the Hugging Face Trainer after resizing the token embeddings. The error occurs due to an inconsistency between the embedding_size in the model configuration and the actual size of the token embeddings after resizing.
Steps to Reproduce:
(See below for reproducible example)
- Load an OLMo model and tokenizer with transformers library.
- Add special tokens to the tokenizer.
- Resize the token embeddings using model.resize_token_embeddings(len(tokenizer)).
- Fine-tune the model using the Hugging Face Trainer.
Error Message: RuntimeError: shape '[-1, 50304]' is invalid for input of size 102772320
Explanation:
The OLMo modeling code assumes that the embedding_size in the configuration (self.config.embedding_size) will always match the actual size of the token embeddings. However, when the token embeddings are resized using resize_token_embeddings, the embedding_size in the configuration remains unchanged, while the actual size of the token embeddings is updated. Note that resize_token_embeddings does update self.config.vocab_size but not self.config.embedding_size (here is the implementation in the transformers library).
In the forward pass, the code reshapes the shift_logits tensor using self.config.embedding_size:
shift_logits = shift_logits.view(-1, self.config.embedding_size) (link)
This leads to a shape mismatch error: the reshape expects the second dimension of shift_logits to be self.config.embedding_size (50304 in this case), but its actual size is the vocabulary size after resizing the token embeddings (e.g., 50282 after adding special tokens). Because the vocab size is smaller than the embedding size, resizing to the vocab size shrinks the embeddings.
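As a quick sanity check on the numbers in the error message, the snippet below tests which candidate widths evenly divide the reported element count. The candidate list is an assumption chosen for illustration: 50304 is the configured embedding_size, 50282 is the example size quoted above, and 50280 is that example size minus the two added special tokens.

```python
# Element count taken from the reported RuntimeError.
numel = 102772320

# Candidate sizes for the last dimension of shift_logits (illustrative assumptions).
candidates = {
    "config.embedding_size": 50304,
    "example resized size": 50282,
    "example size minus two added tokens": 50280,
}


def divisible_widths(numel, candidates):
    """Return {label: rows} for each width that evenly divides numel."""
    return {
        label: numel // width
        for label, width in candidates.items()
        if numel % width == 0
    }


result = divisible_widths(numel, candidates)
for label, rows in result.items():
    print(f"{label}: {rows} rows")
```

Only 50280 divides the count (giving 2044 rows), which suggests the logits in the run that produced this message had a last dimension of 50280. In any case the count is not a multiple of 50304, so view(-1, 50304) cannot succeed.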
Proposed fixes
I see a few different options for addressing this, and am happy to contribute.
- The OLMo modeling code could be updated to use the actual size of the token embeddings (self.transformer.wte.weight.shape[0]) instead of relying on the embedding_size value from the configuration when reshaping shift_logits. This will ensure that the code works correctly even when the token embeddings are resized dynamically.
- Alternatively, a custom implementation of the resize_token_embeddings method that updates self.config.embedding_size to match the actual size of the token embeddings could address this issue.
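To illustrate the first option, here is a minimal sketch of a loss computation that reads the width from the logits tensor itself rather than from the config. It is not the OLMo code: causal_lm_loss is a hypothetical helper, and the toy sizes (16 and 11) merely stand in for the stale config.embedding_size and the resized embedding rows.

```python
import torch
import torch.nn.functional as F


def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy that infers the vocabulary width from the logits."""
    # Drop the last position's logits and the first label so they align.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Use the tensor's own last dimension instead of a config value.
    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
    shift_labels = shift_labels.view(-1)
    return F.cross_entropy(shift_logits, shift_labels)


# Toy sizes standing in for the real ones (hypothetical values).
stale_embedding_size = 16  # plays the role of config.embedding_size
resized_width = 11         # plays the role of the resized embedding rows

logits = torch.randn(2, 5, resized_width)
labels = torch.randint(0, resized_width, (2, 5))

loss = causal_lm_loss(logits, labels)

# The config-based reshape fails on the same tensor: 2 * 4 * 11 = 88 elements
# cannot be viewed as rows of width 16.
try:
    logits[..., :-1, :].contiguous().view(-1, stale_embedding_size)
    stale_view_failed = False
except RuntimeError:
    stale_view_failed = True
```

Reading the width from the tensor keeps the reshape correct regardless of how the embeddings were resized, at the cost of no longer catching a genuinely wrong logits shape at this point.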
More broadly, I realize I did not have to resize the token embeddings in the first place, as they are already considerably larger than the vocab, as noted in the OLMo paper. Before reading this and inspecting the model more closely, I found it surprising that expanding the vocabulary and then resizing the token embeddings based on the new vocab size ended up shrinking the token embeddings. Perhaps a warning message noting the larger token embeddings would be useful here.
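A warning of this kind could look roughly like the sketch below. check_resize_target is a hypothetical helper, not part of transformers or OLMo, and the numbers are the ones from this report.

```python
import warnings


def check_resize_target(current_rows: int, new_num_tokens: int) -> bool:
    """Warn when a resize would shrink the embedding matrix; return True if it shrinks."""
    shrinks = new_num_tokens < current_rows
    if shrinks:
        warnings.warn(
            f"Resizing to {new_num_tokens} rows would shrink the token embeddings "
            f"from {current_rows}; the existing matrix already has room for "
            f"{current_rows - new_num_tokens} more tokens.",
            UserWarning,
            stacklevel=2,
        )
    return shrinks


# Record the warning instead of printing it, so it can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    shrinks = check_resize_target(current_rows=50304, new_num_tokens=50282)
```

With the sizes above the helper reports a shrink and notes that 22 unused rows were already available.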
Reproducible Example
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import hf_olmo
model_ckpt = "allenai/OLMo-1B"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = AutoModelForCausalLM.from_pretrained(
    model_ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
special_tokens = ["<|im_start|>", "<|im_end|>"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
model.resize_token_embeddings(len(tokenizer))
# Create dummy input data
input_text = "This is a dummy input text."
input_ids = tokenizer.encode(input_text, return_tensors="pt")
# Use the same input_ids as labels
labels = input_ids.clone()
# Move input_ids and labels to the same device as the model
input_ids = input_ids.to(model.device)
labels = labels.to(model.device)
# Call the model's forward pass directly
with torch.no_grad():
    outputs = model(input_ids, labels=labels)
Thanks for looking!
Versions
Python 3.10.12
accelerate==0.27.2
ai2-olmo==0.2.5
bitsandbytes==0.43.0
huggingface-hub==0.21.4
numpy==1.23.5
packaging==22.0
pandas==1.5.3
torch==2.0.1+cu118
tqdm==4.64.1
transformers==4.38.2
Hi djliden, thanks for pointing the issue out and for proposing some fixes! The fix I would prefer is a custom implementation that updates the embedding size in self.config and also self.model.config.embedding_size (@AkshitaB feel free to chime in).
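A rough sketch of what such a custom resize could look like, using toy stand-ins rather than the real OLMo or transformers classes. ToyInnerModel, ToyWrapper, and the d_model field are hypothetical, and the sketch leaves out vocab_size handling and output-embedding tying.

```python
from types import SimpleNamespace

import torch
from torch import nn


class ToyInnerModel(nn.Module):
    """Stand-in for the inner model that owns the token embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.wte = nn.Embedding(config.embedding_size, config.d_model)


class ToyWrapper(nn.Module):
    """Stand-in for the Hugging Face wrapper, which has its own config."""

    def __init__(self, config, inner_config):
        super().__init__()
        self.config = config
        self.model = ToyInnerModel(inner_config)

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        old = self.model.wte
        new = nn.Embedding(new_num_tokens, old.embedding_dim)
        # Copy over the rows that exist in both the old and new matrices.
        n = min(old.num_embeddings, new_num_tokens)
        with torch.no_grad():
            new.weight[:n].copy_(old.weight[:n])
        self.model.wte = new
        # Keep both configs in sync with the actual number of embedding rows.
        self.config.embedding_size = new_num_tokens
        self.model.config.embedding_size = new_num_tokens
        return new


wrapper = ToyWrapper(
    SimpleNamespace(embedding_size=16),
    SimpleNamespace(embedding_size=16, d_model=4),
)
wrapper.resize_token_embeddings(11)
```

After the call, both config objects report the same size as the embedding matrix, so a reshape based on either one matches the logits.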
It seems like the Hugging Face vocab_size is supposed to correspond to OLMo's embedding_size, so I think that ideally we would also change our HF configs to remove their embedding_size and use OLMo's embedding_size as their vocab_size instead. We have several hundred checkpoints on HF, so unfortunately this does not seem feasible.
Regarding a warning message, I'm not sure how much value it adds. It could prevent someone from shrinking the embedding size unnecessarily, but I don't see much harm if someone does accidentally shrink it (so long as the size is bigger than vocab_size). On the other hand, it would create noise for a person doing it knowingly.
I am happy to review a PR for a fix to this issue. Please also add @AkshitaB as a reviewer.
I'm closing this since the fix has been merged. Please reopen if the problem persists.