LoRA adapter with new embedding tokens has size and load issues
System Info
I was training google/gemma-3-270m by adding LoRA adapters on the attention layers and new embedding tokens. I am facing the following issues:
- On using `model.save_pretrained`, I observed that the saved adapter has an abnormal size, surprisingly around the size of the complete model. A quick comparison is shown below: if I train the complete embedding layer, the saved adapter size is ~350 MB, but if I add ~30 new tokens and only train those using `trainable_token_indices`, the size becomes ~675 MB!
Model Size Comparison:
Normal LoRA: 35.77 MB
Embedding LoRA (training complete embedding layers): 355.77 MB
Embedding LoRA (training partial embedding layers): 100.12 MB
New Tokens LoRA (training newly added embedding tokens): 675.96 MB
- The saved model cannot be loaded back, as it throws an error about the newly added tokens not being accounted for. I have to load the base model, resize it by adding the new tokens again, and then load the adapter. Is this the expected behavior?
I created a gist to test this behavior. The output of the script is as follows:
============================================================
MODEL SIZE TESTING SCRIPT
============================================================
Attempting to load: google/gemma-3-270m
Successfully loaded: google/gemma-3-270m
Loaded model: google/gemma-3-270m
Vocabulary size: 262,145
Cleaning existing output directory: ./test_model_sizes
Removed existing directory
============================================================
TEST 1: NORMAL LoRA (Attention Layers Only)
============================================================
============================================================
Setting up NORMAL LoRA (attention layers only)
============================================================
LoRA Configuration:
r (rank): 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: {'q_proj', 'v_proj'}
trainable params: 737,280 || all params: 268,835,456 || trainable%: 0.2742
Saving normal LoRA model to: ./test_model_sizes/normal_lora
============================================================
Model saved at: ./test_model_sizes/normal_lora
============================================================
Model Statistics:
Total parameters: 268,835,456
Trainable parameters: 737,280
Vocabulary size: 262,145
Saved Model Size:
Total size: 35.77 MB
Files in saved directory:
adapter_model.safetensors: 2.82 MB
tokenizer_config.json: 1.10 MB
special_tokens_map.json: 662.00 B
tokenizer.json: 31.84 MB
README.md: 5.07 KB
adapter_config.json: 854.00 B
Trainable Parameters Breakdown:
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
============================================================
============================================================
TEST 2: LoRA with EMBEDDING SPACE ADAPTER
============================================================
============================================================
Setting up LoRA with EMBEDDING SPACE ADAPTER (modules_to_save)
============================================================
LoRA Configuration:
r (rank): 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: {'q_proj', 'v_proj'}
modules_to_save: ['embed_tokens']
Original vocab size: 262,145
trainable params: 168,509,440 || all params: 436,607,616 || trainable%: 38.5952
Saving embedding LoRA model to: ./test_model_sizes/embedding_lora
============================================================
Model saved at: ./test_model_sizes/embedding_lora
============================================================
Model Statistics:
Total parameters: 436,607,616
Trainable parameters: 168,509,440
Vocabulary size: 262,145
Saved Model Size:
Total size: 355.77 MB
Files in saved directory:
adapter_model.safetensors: 322.82 MB
tokenizer_config.json: 1.10 MB
special_tokens_map.json: 662.00 B
tokenizer.json: 31.84 MB
README.md: 5.07 KB
adapter_config.json: 874.00 B
Trainable Parameters Breakdown:
base_model.model.model.embed_tokens.modules_to_save.default.weight: torch.Size([262144, 640]) (167,772,160 params)
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
============================================================
============================================================
TEST 3: LoRA with TRAINABLE TOKEN INDICES
============================================================
============================================================
Setting up LoRA with EMBEDDING SPACE ADAPTER (trainable_token_indices)
============================================================
Vocabulary size: 262,145
Training tokens from index 235930 to 262144
Number of trainable token indices: 26,214
LoRA Configuration:
r (rank): 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: {'q_proj', 'v_proj'}
trainable_token_indices: 26,214 tokens
(indices 235930 to 262144)
trainable params: 17,514,240 || all params: 285,612,416 || trainable%: 6.1322
Saving trainable indices LoRA model to: ./test_model_sizes/trainable_indices_lora
============================================================
Model saved at: ./test_model_sizes/trainable_indices_lora
============================================================
Model Statistics:
Total parameters: 285,612,416
Trainable parameters: 17,514,240
Vocabulary size: 262,145
Saved Model Size:
Total size: 100.12 MB
Files in saved directory:
adapter_model.safetensors: 66.82 MB
tokenizer_config.json: 1.10 MB
special_tokens_map.json: 662.00 B
tokenizer.json: 31.84 MB
README.md: 5.07 KB
adapter_config.json: 359.26 KB
Trainable Parameters Breakdown:
base_model.model.model.embed_tokens.token_adapter.trainable_tokens_delta.default: torch.Size([26214, 640]) (16,776,960 params)
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
============================================================
============================================================
TEST 4: LoRA with NEW TOKENS ADDED
============================================================
============================================================
Setting up LoRA with NEW TOKENS (train only new tokens)
============================================================
Original vocabulary size: 262,145
Added 24 new tokens to tokenizer
New vocabulary size: 262,169
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
Resized model embeddings to 262,169 tokens
New token indices: 262145 to 262168
Number of new token indices: 24
LoRA Configuration:
r (rank): 16
lora_alpha: 32
lora_dropout: 0.05
target_modules: {'q_proj', 'v_proj'}
trainable_token_indices: 24 new tokens
(indices 262145 to 262168)
trainable params: 752,640 || all params: 268,866,816 || trainable%: 0.2799
Saving new tokens LoRA model to: ./test_model_sizes/new_tokens_lora
/Users/.../Work/tts/gemma3_audio_codecs/.venv/lib/python3.9/site-packages/peft/utils/save_and_load.py:300: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
warnings.warn(
============================================================
Model saved at: ./test_model_sizes/new_tokens_lora
============================================================
Model Statistics:
Total parameters: 268,866,816
Trainable parameters: 752,640
Vocabulary size: 262,169
Saved Model Size:
Total size: 675.96 MB
Files in saved directory:
adapter_model.safetensors: 643.00 MB
tokenizer_config.json: 1.11 MB
special_tokens_map.json: 662.00 B
tokenizer.json: 31.84 MB
README.md: 5.07 KB
adapter_config.json: 1.19 KB
Trainable Parameters Breakdown:
base_model.model.model.embed_tokens.token_adapter.trainable_tokens_delta.default: torch.Size([24, 640]) (15,360 params)
base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight: torch.Size([1024, 16]) (16,384 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight: torch.Size([16, 640]) (10,240 params)
base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight: torch.Size([256, 16]) (4,096 params)
============================================================
============================================================
SIZE COMPARISON
============================================================
Model Size Comparison:
Normal LoRA: 35.77 MB
Embedding LoRA (modules_to_save): 355.77 MB
Embedding LoRA (trainable_indices): 100.12 MB
New Tokens LoRA: 675.96 MB
Size Ranking (smallest to largest):
1. Normal LoRA: 35.77 MB
2. Embedding LoRA (trainable_indices): 100.12 MB
3. Embedding LoRA (modules_to_save): 355.77 MB
4. New Tokens LoRA: 675.96 MB
Embedding LoRA (modules_to_save) is 9.95x larger than Normal LoRA
Embedding LoRA (trainable_indices) is 2.80x larger than Normal LoRA
New Tokens LoRA is 18.90x larger than Normal LoRA
Embedding LoRA (modules_to_save) is 3.55x larger than trainable_indices
New Tokens LoRA is 6.75x larger than trainable_indices
============================================================
TESTING COMPLETE
============================================================
Who can help?
No response
Reproduction
https://gist.github.com/imohitmayank/cead7ad4a63c8770bbd5a8f48d25aeeb
Expected behavior
IMO, the saved adapter should only contain the trainable parameters, and it should be easily loadable.
As a general remark: Gemma is unusual in that it uses a relatively big vocabulary. Therefore, the embedding layer is particularly large. This is especially noticeable with the smaller Gemma models. For "google/gemma-3-270m", the embedding layer makes up 63% of all parameters. This is why the checkpoint size for this model will vary substantially based on whether the embedding is involved or not. With other models that have smaller embeddings, the difference will not be so extreme.
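A quick back-of-the-envelope check of that proportion, using the numbers from the script output above (just a rough sketch):

# Share of google/gemma-3-270m's parameters taken up by the input embedding.
embed_params = 262_144 * 640        # vocabulary rows x hidden size, from the breakdown above
total_params = 268_835_456          # total parameters reported by the script
print(f"{embed_params / total_params:.1%}")  # roughly 62-63% of all parameters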
Now for the different cases you present:
1. Normal LoRA
I think everything works as expected, no explanation needed.
2. Embedding LoRA (modules_to_save)
This means you want to fully fine-tune the embedding. Therefore, it is saved as part of the checkpoint. Since it's so big, it explains the large checkpoint size:
size = num params of embedding * size of dtype in bytes (bfloat16)
= 167772160 * 2
= 335544320
= 320 MB
3. Embedding LoRA (trainable_indices)
Here, you want to train 10% of the vocab, so this results in:
size = num params of delta * size of dtype in bytes (float32)
= 16776960 * 4
= 67107840
= 64 MB
The remaining size of this directory is made up of 33 MB for the tokenizer.json.
4. New Tokens LoRA
Here you are extending the vocabulary of the base model. Therefore, by default, PEFT will save the base model embedding as part of the checkpoint (bfloat16, 320 MB). Since the LM head is tied to the embedding, it is also stored (another 320 MB). If you don't want the embedding layer to be saved, pass `save_embedding_layers=False` when calling `save_pretrained` (see the sketch below); this reduces the checkpoint size to 3 MB.
If you don't save the embedding but still extend the vocabulary, ensure that for inference you can still restore the previous state (e.g. via a fixed random seed). Otherwise, the underlying embeddings will be different and the model will produce nonsense for them.
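A minimal sketch of that save call, assuming `model` and `lora_config` are set up as in the gist (the names are placeholders):

from peft import get_peft_model

peft_model = get_peft_model(model, lora_config)  # model/lora_config as defined in the gist
# ... training ...
peft_model.save_pretrained(
    "./test_model_sizes/new_tokens_lora",
    save_embedding_layers=False,  # skip the two ~320 MB embedding/lm_head copies
)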
Regarding tied weights, we are currently working on improving the situation.
The saved model cannot be loaded back, as it throws an error about the newly added tokens not being accounted for. I have to load the base model, resize it by adding the new tokens again, and then load the adapter. Is this the expected behavior?
Yes, if you extend the vocabulary before training, you have to perform the same steps again when you load the model. This is the intended way.
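A minimal sketch of that load-back flow, assuming the tokenizer saved alongside the adapter already contains the added tokens (the exact resize may need to mirror what was done before training):

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_dir = "./test_model_sizes/new_tokens_lora"

# The tokenizer saved with the adapter already knows about the new tokens.
tokenizer = AutoTokenizer.from_pretrained(adapter_dir)

# Recreate the resized base model, then load the adapter on top of it.
base = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m")
base.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(base, adapter_dir)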
@imohitmayank
I would also suggest setting the `ensure_weight_tying` flag to True in `LoraConfig` if you add the embedding layer to `modules_to_save`. This keeps the weight tying consistent and marks the embedding layer as trainable, so adapter loading would not break. It would also keep the size to ~320 MB.
Based on the analysis by @BenjaminBossan and you, this seems the most straightforward solution for your use case.
Thanks @BenjaminBossan for the detailed explanation. This clarifies the current implementation. Looking forward to the weight-tying release.
Hey @romitjain, thanks for getting back. I tried setting the `ensure_weight_tying` flag to True, but I am still getting a size of ~670 MB! I have modified the original gist with a case 6 to test this flow: it basically adds new tokens and turns the flag on. Please have a look when possible, in case I am missing something.
Thanks.
@imohitmayank Can you try the `ensure_weight_tying` flag with `modules_to_save`?
Instead of passing trainable token indices, can you please try passing the `embed_tokens` layer to `modules_to_save`?
Yeah, tried that as well, still getting a ~670 MB model:

from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    r=16,  # LoRA attention dimension
    lora_alpha=32,  # Alpha parameter for LoRA scaling
    lora_dropout=0.05,  # Dropout probability for LoRA layers
    bias="none",  # Bias type for LoRA
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],  # Still target attention layers
    # trainable_token_indices={"embed_tokens": new_token_indices},  # Train only new tokens
    modules_to_save=["embed_tokens"],  # Train entire embedding layer (including new tokens)
    ensure_weight_tying=True,  # Keep weight tying consistent and mark embedding layer as trainable
)
@imohitmayank Yes, you are correct. I am not sure what should be done here from the PEFT side; @BenjaminBossan would be the right person for that.
But as far as correctness goes, if you want to use `modules_to_save` with an embedding layer, be sure to add the `ensure_weight_tying` flag. You can, of course, iterate over `adapter_model.safetensors` and remove the `lm_head` tensors manually. Since the `lm_head` layer is tied to `embed_tokens`, you would not be losing anything (for the case I mentioned).
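A minimal sketch of that manual cleanup, assuming the tied copy shows up under keys containing lm_head (the exact key names may differ):

from safetensors.torch import load_file, save_file

path = "./test_model_sizes/embedding_lora/adapter_model.safetensors"
state_dict = load_file(path)

# Drop the tied lm_head copy; embed_tokens already carries the same weights.
filtered = {k: v for k, v in state_dict.items() if "lm_head" not in k}
save_file(filtered, path, metadata={"format": "pt"})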
For the question of whether the embedding is saved as part of the checkpoint, setting `ensure_weight_tying` makes no difference.
Note that if you have `modules_to_save=["embed_tokens"]`, it is required to save the embedding, since this setting implies that we want to fully fine-tune it. We could think about deduplicating the weight if `ensure_weight_tying=True`, but such logic is not implemented right now.
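To double-check what actually ends up in a given adapter file, here is a small inspection sketch (the directory name follows the script above):

from safetensors import safe_open

path = "./test_model_sizes/new_tokens_lora/adapter_model.safetensors"
total_bytes = 0
with safe_open(path, framework="pt") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)
        total_bytes += tensor.numel() * tensor.element_size()
        print(key, tuple(tensor.shape), tensor.dtype)
print(f"Total tensor size: {total_bytes / 1024**2:.2f} MB")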