
Why does Medusa-2 training on Llama-2 show no great improvement?

Open · MeJerry215 opened this issue 1 year ago · 2 comments

Following the axolotl examples/medusa configs, I used vicuna_7b_qlora_stage1.yml and vicuna_7b_qlora_stage2.yml as templates to write my Llama-2 training configs.

However, I didn't get such a great performance improvement. Below are my test results with different numbers of generated tokens on an A100 (results shown in an attached screenshot).

Then I tried the official vicuna-7b Medusa-2 weights, and they do work (screenshot attached).

So I want to know: what is the difference? Is my training config at fault?

MeJerry215 avatar Feb 26 '24 11:02 MeJerry215

@ctlllll

And here is my training config, llama2_7b_stage1.yml:

base_model: Llama-2-7b-hf
base_model_config: Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: ./ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json
    type: sharegpt
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./Llama2_7b_qlora_stage1

adapter: qlora
lora_model_dir:

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
  - lm_head
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0005

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 40
eval_steps: 40
save_steps:
save_total_limit: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

medusa_num_heads: 5
medusa_num_layers: 1
medusa_heads_coefficient: 0.2
medusa_decay_coefficient: 0.8
medusa_logging: true
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
medusa_only_heads: true
ddp_find_unused_parameters: true
# Stage 1: only train the medusa heads
# Stage 2: train the whole model
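
For context, the last few flags are what make this a stage-1 run: per the comment above, medusa_only_heads: true keeps the base LLM frozen so that only the Medusa heads receive gradients. A minimal PyTorch sketch of that idea (this is not axolotl's actual code, and the "medusa_head" parameter-name substring is a hypothetical choice for illustration):

import torch

def freeze_backbone(model: torch.nn.Module) -> None:
    # Stage 1: train only the Medusa heads, freeze everything else.
    # Matching on "medusa_head" in the parameter name is an assumption.
    for name, param in model.named_parameters():
        param.requires_grad = "medusa_head" in name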

And llama2_7b_stage2.yml:

base_model: Llama-2-7b-hf
base_model_config: Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json
    type: sharegpt
dataset_prepared_path:
val_set_size: 0.01
output_dir: ./Llama2_7b_qlora_stage2

adapter: qlora
lora_model_dir: ./Llama2_7b_qlora_stage1

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj
  - lm_head
lora_target_linear:
lora_fan_in_fan_out:
lora_modules_to_save:

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 2
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0005

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 40
eval_steps: 40
save_steps:
save_total_limit: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

medusa_num_heads: 5
medusa_num_layers: 1
medusa_heads_coefficient: 0.2
medusa_decay_coefficient: 0.8
medusa_logging: true
medusa_scheduler: constant
medusa_lr_multiplier: 4.0
# medusa_only_heads: true
# ddp_find_unused_parameters: true
# Stage 1: only train the medusa heads
# Stage 2: train the whole model
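
The stage-2 config differs mainly in that medusa_only_heads is commented out (the whole model trains, per the comment above) and lora_model_dir points at the stage-1 output. The medusa_lr_multiplier: 4.0 presumably gives the heads a higher learning rate than the backbone in this stage; a hypothetical sketch of how such parameter groups could be set up (not axolotl's actual optimizer code, and the "medusa_head" name substring is again an assumption):

import torch

def build_optimizer(model: torch.nn.Module, lr: float = 5e-4, head_mult: float = 4.0):
    # lr matches learning_rate: 0.0005 above; head_mult matches medusa_lr_multiplier: 4.0.
    heads = [p for n, p in model.named_parameters() if "medusa_head" in n]
    rest = [p for n, p in model.named_parameters() if "medusa_head" not in n]
    return torch.optim.AdamW(
        [{"params": rest, "lr": lr},
         {"params": heads, "lr": lr * head_mult}]  # heads train at the multiplied rate
    )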

Thanks a lot.

Also, I looked at the checkpoints you provided, and there are no LoRA adapter weights in them at all. Was the LoRA config simply removed?
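
If the published checkpoints were produced by merging the QLoRA adapter back into the base weights, the LoRA config would indeed disappear. A sketch of how such a merge is commonly done with peft (the paths are placeholders, and whether the official checkpoints were actually made this way is an assumption):

from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Llama-2-7b-hf")         # placeholder path
model = PeftModel.from_pretrained(base, "./Llama2_7b_qlora_stage2")  # stage-2 adapter dir
merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.save_pretrained("./Llama2_7b_medusa_merged")                  # hypothetical output dir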

MeJerry215 avatar Feb 26 '24 11:02 MeJerry215

Same here; I couldn't reproduce Medusa-1 either. Not sure if you have solved the issue. My guess is that the training dataset was not cleaned (or is simply not the one used for the Vicuna model).
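
One quick way to check the dataset hypothesis is to inspect the ShareGPT file directly; a sketch, assuming the usual ShareGPT layout of a list of records each holding a "conversations" list of turns:

import json

# Assumes each record looks like {"conversations": [{"from": "human"/"gpt", "value": "..."}]}.
with open("ShareGPT_V4.3_unfiltered_cleaned_split.json") as f:
    data = json.load(f)
print(len(data), "conversations")    # rough size check against what Vicuna reports
print(data[0]["conversations"][:2])  # eyeball the first couple of turns for cleanliness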

chenhan97 avatar Mar 27 '24 03:03 chenhan97