
Add support for loading checkpoints with newly added tokens.

Open · charlesCXK opened this issue on Mar 22 '24 · 6 comments

charlesCXK · Mar 22 '24

Wait, would this load the lm_head and embed_tokens matrices correctly?

danielhanchen · Mar 23 '24

Would it not cause them to be randomly initialized?

danielhanchen · Mar 23 '24
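
(Background on the question above: by default, resize_token_embeddings fills any newly appended rows from the model's weight initializer, so they do start out effectively random until overwritten. A minimal sketch to observe this, assuming a standard Hugging Face model:)

# Illustrative sketch: rows appended by resize_token_embeddings start out
# randomly initialized (drawn from the model's weight initializer).
old_len = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(old_len + 2)
new_rows = model.get_input_embeddings().weight.data[-2:]
print(new_rows)  # fresh random vectors until explicitly re-initialized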

> Would it not cause them to be randomly initialized?

I have tested the code with the following setup:

  1. First, I add new tokens to the tokenizer:
########################################
# Add special tokens to the tokenizer.
########################################
if True:
    # Note: len(tokenizer) counts added tokens; tokenizer.vocab_size does not.
    old_vocab_size = len(tokenizer)
    print('old vocab size: ', old_vocab_size)
    tokenizer.add_tokens(["<NEWTOKEN>", "</NEWTOKEN>"], special_tokens=True)

    # test case
    print(tokenizer.tokenize("This is an example with <NEWTOKEN> and </NEWTOKEN> token."))  

    # We resize the embeddings to avoid index errors.
    model.resize_token_embeddings(len(tokenizer))
    model.config.vocab_size = len(tokenizer)

    # average init the new token embeddings
    num_new_tokens = len(tokenizer) - old_vocab_size
    print("num_new_tokens:", num_new_tokens)
    input_embeddings = model.get_input_embeddings().weight.data
    output_embeddings = model.get_output_embeddings().weight.data
    input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True)
    output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
        dim=0, keepdim=True)
    input_embeddings[-num_new_tokens:] = input_embeddings_avg
    output_embeddings[-num_new_tokens:] = output_embeddings_avg
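
    # Sanity check (illustrative): the appended rows should now all equal
    # the mean of the pre-existing embeddings.
    import torch
    assert torch.allclose(input_embeddings[-num_new_tokens:], input_embeddings_avg)
    assert torch.allclose(output_embeddings[-num_new_tokens:], output_embeddings_avg)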

    # unfreeze the lm_head and input embeddings so they are trainable
    model.lm_head.weight.requires_grad = True
    model.get_input_embeddings().weight.requires_grad = True
  2. I train the model for several steps and save the LoRA checkpoint:
import os
import shutil

save_path = "/home/xxx"
if os.path.exists(save_path):
    shutil.rmtree(save_path)
model.save_pretrained(save_path)
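# Suggested addition: also save the tokenizer, so the added tokens
# are available again when the checkpoint is reloaded.
tokenizer.save_pretrained(save_path)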
  3. Then I use the saved checkpoint for inference:
print('Use saved model for inference.')
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = save_path, # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    new_token_num = 0,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[
    "Continue the fibonnaci sequence. 1, 1, 2, 3, 5, 8"
], return_tensors = "pt").to("cuda")

from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128)
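
# Illustrative round-trip check: the added tokens should still map to
# single ids in the reloaded tokenizer.
print(tokenizer.convert_tokens_to_ids("<NEWTOKEN>"))
print(tokenizer.tokenize("between <NEWTOKEN> and </NEWTOKEN>"))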
  4. The output is the same as the original model's.

charlesCXK · Mar 24 '24

Hi @charlesCXK, when using this code, I noticed that the loaded model doesn't include the new token that I added before fine-tuning. Do you have to add the new token again for inference? For example:

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = save_path, # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    new_token_num = 1,        # 1 newly added token
)
if "<pad>" not in tokenizer.get_vocab():
    tokenizer.add_tokens(["<pad>"], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))  

# Inference code goes here
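
# Illustrative consistency check: the embedding matrix and tokenizer
# should agree in size after re-adding the token.
assert model.get_input_embeddings().weight.shape[0] == len(tokenizer)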

chtmp223 · Apr 04 '24

Whoopsies sorry on the horrible delay - I'll review this PR and test it out - so sorry!

danielhanchen · Apr 05 '24

@charlesCXK @chtmp223 Extreme apologies on the delay - I think I might have fixed it. You need to call add_new_tokens before get_peft_model, so that the vocab is updated, the embeddings are resized, and the learnt embeddings are saved:

from unsloth import add_new_tokens
from unsloth import FastLanguageModel

add_new_tokens(model, tokenizer, ["new_token_1", "new_token_2"])
model = FastLanguageModel.get_peft_model(model, ...)
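
For completeness, a minimal end-to-end sketch of that order of operations. The model name and the extra target_modules entries are assumptions (they mirror Unsloth's continued-pretraining examples), not something stated in this thread:

from unsloth import FastLanguageModel, add_new_tokens

# Load the base model first (model name is illustrative).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/mistral-7b-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit = True,
)

# Add tokens BEFORE attaching the LoRA adapters, so the resized
# embed_tokens / lm_head are tracked and saved correctly.
add_new_tokens(model, tokenizer, new_tokens = ["new_token_1", "new_token_2"])

# Include embed_tokens and lm_head so the new embedding rows are actually
# trained (an assumption based on Unsloth's continued-pretraining examples).
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",
                      "embed_tokens", "lm_head"],
    lora_alpha = 16,
)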

danielhanchen · Apr 21 '24