LoRA adapter with new embedding tokens has size and loading issues

Open imohitmayank opened this issue 1 month ago • 7 comments

System Info

I was training google/gemma-3-270m by adding LoRA adapters on the attention layers and new embedding tokens. I am facing the following issues:

  • On calling model.save_pretrained, I observed that the saved adapter has an abnormal size, roughly that of the complete model. A quick comparison is shown below: if I train the complete embedding layer the saved adapter is ~350 MB, but if I add ~30 new tokens and train only those using trainable_token_indices, the size becomes ~675 MB!
📊 Model Size Comparison:
  Normal LoRA:                    35.77 MB
  Embedding LoRA (training the complete embedding layer): 355.77 MB
  Embedding LoRA (training part of the embedding layer): 100.12 MB
  New Tokens LoRA (training only newly added token embeddings): 675.96 MB
  • The saved model cannot be loaded back; it throws an error about the newly added tokens not being accounted for. I have to load the base model, resize it by adding the new tokens again, and then load the adapter. Is this the expected behavior?

I created a gist to test this behavior. The output of the script is as follows:


============================================================
MODEL SIZE TESTING SCRIPT
============================================================

๐Ÿ” Attempting to load: google/gemma-3-270m
โœ… Successfully loaded: google/gemma-3-270m

โœ… Loaded model: google/gemma-3-270m
   Vocabulary size: 262,145

🧹 Cleaning existing output directory: ./test_model_sizes
✅ Removed existing directory

============================================================
TEST 1: NORMAL LoRA (Attention Layers Only)
============================================================

============================================================
Setting up NORMAL LoRA (attention layers only)
============================================================

LoRA Configuration:
  r (rank): 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: {'q_proj', 'v_proj'}
trainable params: 737,280 || all params: 268,835,456 || trainable%: 0.2742

💾 Saving normal LoRA model to: ./test_model_sizes/normal_lora

============================================================
Model saved at: ./test_model_sizes/normal_lora
============================================================

📊 Model Statistics:
  Total parameters: 268,835,456
  Trainable parameters: 737,280
  Vocabulary size: 262,145

💾 Saved Model Size:
  Total size: 35.77 MB

๐Ÿ“ Files in saved directory:
    adapter_model.safetensors: 2.82 MB
    tokenizer_config.json: 1.10 MB
    special_tokens_map.json: 662.00 B
    tokenizer.json: 31.84 MB
    README.md: 5.07 KB
    adapter_config.json: 854.00 B

🔧 Trainable Parameters Breakdown:
    base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)

============================================================

============================================================
TEST 2: LoRA with EMBEDDING SPACE ADAPTER
============================================================

============================================================
Setting up LoRA with EMBEDDING SPACE ADAPTER (modules_to_save)
============================================================

LoRA Configuration:
  r (rank): 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: {'q_proj', 'v_proj'}
  modules_to_save: ['embed_tokens']
  Original vocab size: 262,145
trainable params: 168,509,440 || all params: 436,607,616 || trainable%: 38.5952

💾 Saving embedding LoRA model to: ./test_model_sizes/embedding_lora

============================================================
Model saved at: ./test_model_sizes/embedding_lora
============================================================

📊 Model Statistics:
  Total parameters: 436,607,616
  Trainable parameters: 168,509,440
  Vocabulary size: 262,145

💾 Saved Model Size:
  Total size: 355.77 MB

๐Ÿ“ Files in saved directory:
    adapter_model.safetensors: 322.82 MB
    tokenizer_config.json: 1.10 MB
    special_tokens_map.json: 662.00 B
    tokenizer.json: 31.84 MB
    README.md: 5.07 KB
    adapter_config.json: 874.00 B

🔧 Trainable Parameters Breakdown:
    base_model.model.model.embed_tokens.modules_to_save.default.weight: torch.Size([262144, 640]) (167,772,160 params)
    base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)

============================================================

============================================================
TEST 3: LoRA with TRAINABLE TOKEN INDICES
============================================================

============================================================
Setting up LoRA with EMBEDDING SPACE ADAPTER (trainable_token_indices)
============================================================

  Vocabulary size: 262,145
  Training tokens from index 235930 to 262144
  Number of trainable token indices: 26,214

LoRA Configuration:
  r (rank): 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: {'q_proj', 'v_proj'}
  trainable_token_indices: 26,214 tokens
    (indices 235930 to 262144)
trainable params: 17,514,240 || all params: 285,612,416 || trainable%: 6.1322

💾 Saving trainable indices LoRA model to: ./test_model_sizes/trainable_indices_lora

============================================================
Model saved at: ./test_model_sizes/trainable_indices_lora
============================================================

📊 Model Statistics:
  Total parameters: 285,612,416
  Trainable parameters: 17,514,240
  Vocabulary size: 262,145

💾 Saved Model Size:
  Total size: 100.12 MB

๐Ÿ“ Files in saved directory:
    adapter_model.safetensors: 66.82 MB
    tokenizer_config.json: 1.10 MB
    special_tokens_map.json: 662.00 B
    tokenizer.json: 31.84 MB
    README.md: 5.07 KB
    adapter_config.json: 359.26 KB

🔧 Trainable Parameters Breakdown:
    base_model.model.model.embed_tokens.token_adapter.trainable_tokens_delta.default: torch.Size([26214, 640]) (16,776,960 params)
    base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)

============================================================

============================================================
TEST 4: LoRA with NEW TOKENS ADDED
============================================================

============================================================
Setting up LoRA with NEW TOKENS (train only new tokens)
============================================================

  Original vocabulary size: 262,145
  Added 24 new tokens to tokenizer
  New vocabulary size: 262,169
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  Resized model embeddings to 262,169 tokens
  New token indices: 262145 to 262168
  Number of new token indices: 24

LoRA Configuration:
  r (rank): 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: {'q_proj', 'v_proj'}
  trainable_token_indices: 24 new tokens
    (indices 262145 to 262168)
trainable params: 752,640 || all params: 268,866,816 || trainable%: 0.2799

💾 Saving new tokens LoRA model to: ./test_model_sizes/new_tokens_lora
/Users/.../Work/tts/gemma3_audio_codecs/.venv/lib/python3.9/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(

============================================================
Model saved at: ./test_model_sizes/new_tokens_lora
============================================================

📊 Model Statistics:
  Total parameters: 268,866,816
  Trainable parameters: 752,640
  Vocabulary size: 262,169

💾 Saved Model Size:
  Total size: 675.96 MB

๐Ÿ“ Files in saved directory:
    adapter_model.safetensors: 643.00 MB
    tokenizer_config.json: 1.11 MB
    special_tokens_map.json: 662.00 B
    tokenizer.json: 31.84 MB
    README.md: 5.07 KB
    adapter_config.json: 1.19 KB

🔧 Trainable Parameters Breakdown:
    base_model.model.model.embed_tokens.token_adapter.trainable_tokens_delta.default: torch.Size([24, 640]) (15,360 params)
    base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
    base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)

============================================================

============================================================
SIZE COMPARISON
============================================================

📊 Model Size Comparison:
  Normal LoRA:                    35.77 MB
  Embedding LoRA (modules_to_save): 355.77 MB
  Embedding LoRA (trainable_indices): 100.12 MB
  New Tokens LoRA:                 675.96 MB

📈 Size Ranking (smallest to largest):
  1. Normal LoRA: 35.77 MB
  2. Embedding LoRA (trainable_indices): 100.12 MB
  3. Embedding LoRA (modules_to_save): 355.77 MB
  4. New Tokens LoRA: 675.96 MB

  Embedding LoRA (modules_to_save) is 9.95x larger than Normal LoRA
  Embedding LoRA (trainable_indices) is 2.80x larger than Normal LoRA
  New Tokens LoRA is 18.90x larger than Normal LoRA
  Embedding LoRA (modules_to_save) is 3.55x larger than trainable_indices
  New Tokens LoRA is 6.75x larger than trainable_indices

============================================================
TESTING COMPLETE
============================================================

Who can help?

No response

Reproduction

https://gist.github.com/imohitmayank/cead7ad4a63c8770bbd5a8f48d25aeeb
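
For quick reference, the core of the failing case (Test 4 in the gist) looks roughly like this; the full script is in the gist and the new token names below are only illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model

model_name = "google/gemma-3-270m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Add new tokens and resize the embedding accordingly
new_tokens = [f"<new_tok_{i}>" for i in range(24)]  # illustrative token names
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))
new_token_indices = list(range(len(tokenizer) - len(new_tokens), len(tokenizer)))

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices={"embed_tokens": new_token_indices},  # train only the new rows
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

# This save produces the ~675 MB adapter directory
peft_model.save_pretrained("./test_model_sizes/new_tokens_lora")
tokenizer.save_pretrained("./test_model_sizes/new_tokens_lora")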

Expected behavior

IMO, the saved adapter should contain only the trainable parameters, and it should be easily loadable.

imohitmayank avatar Nov 05 '25 15:11 imohitmayank

As a general remark: Gemma is unusual in that it uses a relatively big vocabulary. Therefore, the embedding layer is particularly large. This is especially noticeable with the smaller Gemma models. For "google/gemma-3-270m", the embedding layer makes up 63% of all parameters. This is why the checkpoint size for this model will vary substantially based on whether the embedding is involved or not. With other models that have smaller embeddings, the difference will not be so extreme.

Now for the different cases you present:

1. Normal LoRA

I think everything works as expected, no explanation needed.

2. Embedding LoRA (modules_to_save)

This means you want to fully fine-tune the embedding. Therefore, it is saved as part of the checkpoint. Since it's so big, it explains the large checkpoint size:

size = num params of embedding * size of dtype in bytes (bfloat16)
     = 167,772,160 * 2
     = 335,544,320 bytes
     ≈ 320 MB

3. Embedding LoRA (trainable_indices)

Here, you train 10% of the vocab, which results in:

size = num params of delta * size of dtype in bytes (float32)
     = 16,776,960 * 4
     = 67,107,840 bytes
     ≈ 64 MB

The remaining ~33 MB of the directory is mostly the tokenizer.json (31.84 MB).
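
As a quick sanity check, the same arithmetic in Python (numbers taken from the log above):

# Full embedding saved via modules_to_save, stored in bfloat16 (2 bytes per param)
emb_params = 262_144 * 640
print(emb_params * 2 / 2**20)    # -> 320.0 MiB

# Delta for the 26,214 trainable token indices, stored in float32 (4 bytes per param)
delta_params = 26_214 * 640
print(delta_params * 4 / 2**20)  # -> ~64 MiB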

4. New Tokens LoRA

Here you are extending the vocabulary of the base model. Therefore, by default, PEFT will save the base model embedding as part of the checkpoint (bfloat16, 320 MB). Since the LM head is tied to the embedding, it is also stored (another 320 MB). If you don't want the embedding layer to be saved, pass save_embedding_layers=False when calling save_pretrained. This reduces the checkpoint size to 3 MB.
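
For example, with the PeftModel from the script above, that call would look roughly like this:

# Skip storing the (resized) embedding and the tied lm_head with the adapter
peft_model.save_pretrained(
    "./test_model_sizes/new_tokens_lora",
    save_embedding_layers=False,
)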

If you don't save the embedding but still extend the vocabulary, make sure that for inference you can restore the same embedding state (e.g. via a fixed random seed for the resize). Otherwise, the underlying embeddings for the new tokens will be different and the model will produce nonsense for them.

Regarding tied weights, we are currently working on improving the situation.

The saved model cannot be loaded back; it throws an error about the newly added tokens not being accounted for. I have to load the base model, resize it by adding the new tokens again, and then load the adapter. Is this the expected behavior?

Yes, when you extend the vocabulary before training, you have to perform the same steps again when loading the model. This is the intended way.
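
In code, that reload flow looks roughly like this (paths are illustrative and assume the tokenizer with the added tokens was saved next to the adapter):

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_dir = "./test_model_sizes/new_tokens_lora"
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)   # contains the added tokens
model = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m")
model.resize_token_embeddings(len(tokenizer))            # same resize as before training
model = PeftModel.from_pretrained(model, adapter_dir)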

BenjaminBossan avatar Nov 06 '25 13:11 BenjaminBossan

@imohitmayank I would also suggest setting the ensure_weight_tying flag to True in LoraConfig if you add the embedding layer to modules_to_save. This keeps the weight tying consistent and marks the embedding layer as trainable, so the adapter loading would not break. It would also keep the size at ~320 MB.

Based on the analysis by @BenjaminBossan and you, this seems like the most straightforward solution for your use case.

romitjain avatar Nov 12 '25 08:11 romitjain

Thanks @BenjaminBossan for the detailed explanation. This clarifies the current implementation. Looking forward to the weight-tying improvements.

imohitmayank avatar Nov 13 '25 06:11 imohitmayank

@imohitmayank I would also suggest setting the ensure_weight_tying flag to True in LoraConfig if you add the embedding layer to modules_to_save. This keeps the weight tying consistent and marks the embedding layer as trainable, so the adapter loading would not break. It would also keep the size at ~320 MB.

Based on the analysis by @BenjaminBossan and you, this seems like the most straightforward solution for your use case.

Hey @romitjain, thanks for getting back. I tried setting the ensure_weight_tying flag to True, but I am still getting a size of ~670 MB! I have modified the original gist with a case 6 to test this flow: it basically adds new tokens and turns the flag on. Please have a look when you get a chance, in case I am missing something. Thanks.

imohitmayank avatar Nov 13 '25 06:11 imohitmayank

@imohitmayank Can you try the ensure_weight_tying flag with modules_to_save? Instead of passing trainable_token_indices, can you please try passing the embed_tokens layer to modules_to_save?

romitjain avatar Nov 13 '25 08:11 romitjain

@imohitmayank Can you try the ensure_weight_tying flag with modules_to_save? Instead of passing trainable_token_indices, can you please try passing the embed_tokens layer to modules_to_save?

Yeah, I tried that as well; I still get a ~670 MB model.

from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=16,  # LoRA attention dimension
    lora_alpha=32,  # Alpha parameter for LoRA scaling
    lora_dropout=0.05,  # Dropout probability for LoRA layers
    bias="none",  # Bias type for LoRA
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],  # Still target attention layers
    # trainable_token_indices={"embed_tokens": new_token_indices},  # Train only new tokens
    modules_to_save=["embed_tokens"],  # Train entire embedding layer (including new tokens)
    ensure_weight_tying=True,  # Keep weight tying consistent and mark embedding layer as trainable
)

imohitmayank avatar Nov 14 '25 11:11 imohitmayank

@imohitmayank Yes, you are correct. I am not sure what should be done here from the PEFT side. @BenjaminBossan would be the correct person for that.

But as far as correctness goes, if you want to use modules_to_save with an embedding layer, be sure to set the ensure_weight_tying flag. You can, of course, iterate over adapter_model.safetensors and remove the lm_head weights manually. Since lm_head is tied to embed_tokens, you would not be losing anything (for the case I mentioned).
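
A sketch of that manual cleanup, assuming the default safetensors checkpoint name (back up the file first):

from safetensors.torch import load_file, save_file

path = "./test_model_sizes/embedding_lora/adapter_model.safetensors"
state = load_file(path)
# Drop the lm_head copy; it is tied to embed_tokens anyway
filtered = {k: v for k, v in state.items() if "lm_head" not in k}
save_file(filtered, path, metadata={"format": "pt"})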

romitjain avatar Nov 14 '25 16:11 romitjain

For the question of whether the embedding is saved as part of the checkpoint, setting ensure_weight_tying makes no difference.

Note that if you have modules_to_save=["embed_tokens"], it is required to save the embedding, since this setting implies that we want to fully fine-tune it. We could think about deduplicating the weight if ensure_weight_tying=True, but such logic is not implemented right now.

BenjaminBossan avatar Nov 17 '25 14:11 BenjaminBossan