Sub-Par Tuned Model Quality Compared to LlamaFactory
Please check that this issue hasn't been reported before.
- [x] I searched previous Bug Reports and didn't find any similar reports.
Expected Behavior
I used the allenai/tulu-3-sft-mixture dataset to tune a pretrained Llama 3.1-8B model for instruction following.
I tested both Axolotl and LlamaFactory. Using the same hyperparams, the best saved checkpoint from Axolotl consistently underperforms the LlamaFactory one by a large margin on the IFEval metric.
LlamaFactory:
|Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval| 4|none | 0|inst_level_loose_acc |↑ |0.7974|± | N/A|
| | |none | 0|inst_level_strict_acc |↑ |0.7698|± | N/A|
| | |none | 0|prompt_level_loose_acc |↑ |0.7246|± |0.0192|
| | |none | 0|prompt_level_strict_acc|↑ |0.6987|± |0.0197|
Axolotl:
|Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval| 4|none | 0|inst_level_loose_acc |↑ |0.6403|± | N/A|
| | |none | 0|inst_level_strict_acc |↑ |0.6151|± | N/A|
| | |none | 0|prompt_level_loose_acc |↑ |0.5471|± |0.0214|
| | |none | 0|prompt_level_strict_acc|↑ |0.5176|± |0.0215|
I expect some differences between frameworks, but did not expect the gap to be this large.
Current behaviour
I'd expect the tuned model's quality to be similar to that of the LlamaFactory-tuned model, but as the IFEval results above show, the Axolotl checkpoint falls well short.
Steps to reproduce
axolotl train <config-file>
Config yaml
base_model: meta-llama/Meta-Llama-3-8B
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
chat_template: jinja
chat_template_jinja: <path/to/tulu3.jinja>
datasets:
- path: allenai/tulu-3-sft-mixture
split: train
type: chat_template
field_messages: messages
dataset_prepared_path: last_run_prepared
dataloader_num_workers: 128
dataloader_pin_memory: true
dataset_processes: 128
val_set_size: 0.05
output_dir: sft-llama3-no-packing-v1
sequence_len: 4096
sample_packing: false
flash_attention: true
gradient_checkpointing: false
wandb_project:
wandb_entity:
wandb_name:
gradient_accumulation_steps: 8
micro_batch_size: 2
eval_batch_size: 4
num_epochs: 3
optimizer: adamw_torch_fused
learning_rate: 5e-6
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-5
max_grad_norm: 1.0
lr_scheduler: cosine
cosine_min_lr_ratio: 0.1
warmup_steps: 500
cosine_constant_lr_ratio: 0.0
bf16: true
tf32: true
logging_steps: 1
eval_strategy: "steps"
eval_steps: 1000
save_strategy: "best"
metric_for_best_model: "eval_loss"
greater_is_better: false
save_total_limit: 3
resume_from_checkpoint:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: SHARDED_STATE_DICT
reshard_after_forward: true
tulu3.jinja
{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{% if not loop.last %}{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}{% else %}{{ '<|assistant|>\n' + message['content'] + eos_token }}{% endif %}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}{% endfor %}
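For anyone comparing the two setups, a quick way to sanity-check what this template actually produces (turn separators, where the EOS token lands) is to render it through transformers. A minimal, illustrative sketch - the model path and messages below are placeholders, not the exact training data:

```python
# Illustrative sketch: render tulu3.jinja via the tokenizer to inspect the
# formatted text. Model path and messages are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
with open("tulu3.jinja") as f:
    tokenizer.chat_template = f.read()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "List three colors."},
    {"role": "assistant", "content": "Red, green, and blue."},
]

# tokenize=False returns the raw string, so EOS placement is easy to eyeball;
# eos_token inside the template is filled in from the tokenizer's special tokens.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```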
Possible solution
No response
Which Operating Systems are you using?
- [ ] Linux
- [ ] macOS
- [ ] Windows
Python Version
3.11.13
axolotl branch-commit
main/050210e637a7ca2fdb65491eced13bd4d1ce5d10
Acknowledgements
- [x] My issue title is concise, descriptive, and in title casing.
- [x] I have searched the existing issues to make sure this bug has not been reported yet.
- [x] I am using the latest version of axolotl.
- [x] I have provided enough information for the maintainers to reproduce and diagnose the issue.
That's a very interesting observation. We both use transformers under the hood (just different speed/VRAM optimizations), so I'm surprised the differences are this large.
Could you provide a repro for running on llama factory too?
Thanks @NanoCode012. Here's an example llama-factory config we used.
### model
model_name_or_path: meta-llama/Llama-3.1-8B
trust_remote_code: true
### method
stage: sft
do_train: true
finetuning_type: full
# deepspeed: examples/deepspeed/ds_z2_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
include_effective_tokens_per_second: true
flash_attn: fa2
enable_liger_kernel: true
packing: false
include_num_input_tokens_seen: true
### dataset
dataset: tulu_3_sft_mixture
template: "{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{% if not loop.last %}{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}{% else %}{{ '<|assistant|>\n' + message['content'] + eos_token }}{% endif %}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}{% endfor %}"
cutoff_len: 4096
# max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 16
### output
output_dir: saves/llama3p1-8b/full/fa2_liger_fsdp_chat_template
logging_steps: 10
save_steps: 5000
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: wandb # choices: [none, wandb, tensorboard, swanlab, mlflow]
run_name: llama3p1-8b-sft-fa2_liger_fsdp
### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
learning_rate: 5.0e-6
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-5
num_train_epochs: 4.0
lr_scheduler_type: cosine
warmup_ratio: 0.03
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null
### eval
# eval_dataset: alpaca_en_demo
val_size: 0.05
per_device_eval_batch_size: 4
eval_strategy: steps
eval_steps: 500
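(For reproduction: LLaMA-Factory configs like this are typically launched with `llamafactory-cli train <config-file>`; the exact invocation on our side may have differed slightly.)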
meta-llama/Meta-Llama-3-8B and meta-llama/Llama-3.1-8B are different models. Unless that was a typo?
@xzuyn That's a typo from when I was cleaning up the Axolotl config. Sorry about that. I was using a local path to the Llama 3.1-8B model with Axolotl, so both experiments used the same Llama 3.1-8B base model.
Thanks for the repro. We will look into this
Heyy, I tried to repro this on llama factory. Do I need to register the chat template manually? Also, regarding the dataset: I need to add it to the dataset.json file - can you also share the functions and everything?
@ved1beta Thanks for trying
> Do I need to register the chat template manually?
You can add the template to the tokenizer_config.json
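One way to do that without hand-editing the JSON (illustrative sketch, paths are placeholders) is to set the template on the tokenizer and save it back out; recent transformers versions serialize chat_template into tokenizer_config.json when saving:

```python
# Sketch: persist the custom chat template so it ends up in tokenizer_config.json
# alongside the other tokenizer settings. Paths are placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
with open("tulu3.jinja") as f:
    tokenizer.chat_template = f.read()
tokenizer.save_pretrained("./llama3p1-8b-with-tulu3-template")
```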
> I need to add it to the dataset.json file - can you also share the functions?
This is what's added:
"tulu_3_sft_mixture": {
"hf_hub_url": "allenai/tulu-3-sft-mixture",
"formatting": "sharegpt",
"columns": {
"messages": "messages"
},
"tags": {
"role_tag": "role",
"content_tag": "content",
"user_tag": "user",
"assistant_tag": "assistant",
"system_tag": "system"
}
},
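(Illustrative: if you want to confirm the column layout matches this mapping before training, you can peek at the dataset straight from the Hub.)

```python
# Sketch: verify the dataset exposes a "messages" column of role/content dicts.
from datasets import load_dataset

ds = load_dataset("allenai/tulu-3-sft-mixture", split="train", streaming=True)
example = next(iter(ds))
print(example["messages"][0])  # expected shape: {"role": ..., "content": ...}
```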
@weiran-work How many GPUs for FSDP? H100s? This affects reproducibility on our end since it affects the effective batch size. Thanks!
@winglian We used 8xH100; the effective batch size is 128.
@weiran-work I think the problem lies in the different starting base models. You used the older llama 3 in axolotl and newer 3.1 for llama-factory.
base_model: meta-llama/Meta-Llama-3-8B
vs
model_name_or_path: meta-llama/Llama-3.1-8B
> @xzuyn That's a typo from when I was cleaning up the Axolotl config. Sorry about that. I was using a local path to the Llama 3.1-8B model with Axolotl, so both experiments used the same Llama 3.1-8B base model.
@winglian OP addressed this above.
Based on our latest experiments, one contributing factor is FSDP1 vs. FSDP2. We suspect this comes from the upstream HF implementation. We'll just stay away from FSDP2 for now.
FSDP1 result
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|--------------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval | 4|none | 0|inst_level_loose_acc |↑ |0.7530|± | N/A|
| | |none | 0|inst_level_strict_acc |↑ |0.7278|± | N/A|
| | |none | 0|prompt_level_loose_acc |↑ |0.6654|± |0.0203|
| | |none | 0|prompt_level_strict_acc|↑ |0.6396|± |0.0207|
FSDP2 result
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|--------------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval | 4|none | 0|inst_level_loose_acc |↑ |0.4065|± | N/A|
| | |none | 0|inst_level_strict_acc |↑ |0.3741|± | N/A|
| | |none | 0|prompt_level_loose_acc |↑ |0.2791|± |0.0193|
| | |none | 0|prompt_level_strict_acc|↑ |0.2458|± |0.0185|
Loss graphs
@weiran-work Thanks for sharing. Did you run these fsdp experiments on Axolotl? If yes, could we have the configs to repro? Perhaps there's some knob that's causing this discrepancy.
Otherwise, let us know how to repro
Yes, it is with Axolotl
This is the FSDP1 config
base_model: meta-llama/Llama-3.1-8B
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
chat_template: llama3
datasets:
- path: allenai/tulu-3-sft-mixture
split: train
type: chat_template
field_messages: messages
dataset_prepared_path: last_run_prepared
dataloader_num_workers: 128
dataloader_pin_memory: true
dataset_processes: 128
val_set_size: 0.05
output_dir: sft-llama3-packed-fsdp1
sequence_len: 4096
sample_packing: true
flash_attention: true
gradient_checkpointing: false
wandb_project:
wandb_entity:
wandb_name:
# max_steps: 10
gradient_accumulation_steps: 8
micro_batch_size: 2
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_torch_fused
learning_rate: 5e-6
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-5
max_grad_norm: 1.0
lr_scheduler: cosine
cosine_min_lr_ratio: 0.1
warmup_steps: 500
cosine_constant_lr_ratio: 0.0
bf16: true
tf32: true
logging_steps: 1
eval_strategy: "steps"
eval_steps: 200
save_strategy: "best"
metric_for_best_model: "eval_loss"
greater_is_better: false
save_total_limit: 8
save_only_model: true
resume_from_checkpoint:
special_tokens:
pad_token: <|finetune_right_pad_id|>
eos_token: <|eot_id|>
fsdp:
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: false
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: true
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: SHARD_GRAD_OP
fsdp_backward_prefetch: BACKWARD_PRE
For the FSDP2 config, the block below is the only diff; it replaces the fsdp/fsdp_config section of the FSDP1 config above:
fsdp_version: 2
fsdp_config:
offload_params: false
cpu_ram_efficient_loading: true
auto_wrap_policy: TRANSFORMER_BASED_WRAP
transformer_layer_cls_to_wrap: LlamaDecoderLayer
state_dict_type: FULL_STATE_DICT
reshard_after_forward: true
btw, did you use FSDP2 with llama-factory as well?
@winglian llama-factory didn't officially support FSDP2 when we tested it a few weeks ago.
Hi @weiran-work. Thanks so much for reporting this issue and helping us reproduce it.
I don't think this issue is fundamentally related to FSDP1 vs. FSDP2. I believe it is due to how Accelerate handles model precision for FSDP - the default behaviour in Accelerate is to forcibly upcast the model to FP32 for both FSDP1 and FSDP2. We have disabled this behaviour for FSDP2 but not for FSDP1, as we are aiming to deprecate FSDP1 soon.
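If you'd like to verify the upcast on your end, a rough, framework-agnostic sketch is to tally parameter dtypes on whatever model object the trainer holds after wrapping (nothing below is axolotl-specific; `trainer.model` is just an example handle):

```python
# Rough sketch: summarise parameter dtypes on the wrapped model to see whether
# the weights were silently upcast to fp32. `model` is whatever your trainer holds.
from collections import Counter

def dtype_summary(model):
    counts = Counter(str(p.dtype) for p in model.parameters())
    for dtype, n in sorted(counts.items()):
        print(f"{dtype}: {n} parameter tensors")

# e.g. dtype_summary(trainer.model)  # expect torch.bfloat16 if no upcast happened
```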
Could you try running your FSDP2 config using the fsdp2_fp32 branch? I've reverted the upcast-disabling behaviour for FSDP2 so you should see pretty much identical memory usage and loss curves between FSDP1 and FSDP2 - I've verified this on our end and you can see my loss curves below.
As for a solution, we can make it easier to configure FP32 fine-tuning, as it's currently not possible to do so in axolotl without using my branch above - is this something you would find helpful? We historically have not focused on this feature, as many users find the tradeoff of reduced performance for significantly less memory usage acceptable.
A general note on reproducibility - it's worth fixing your seed and dataset_prepared_path in your config when comparing different runs, as this ensures different runs use the same pre-processed dataset and random seed.
@weiran-work separately I'm impressed at your thoroughness in investigating this issue - we'd be interested in hearing more about how you're using Axolotl and to see if we could provide more support. Would you be up for a quick chat?
@SalmanMohammadi Thank you for root causing this! We'll give that branch a try end of this week or early next week.
And I personally think exposing that upcast as an option would be great. Precision does matter, especially for CPT or instruction tuning.
Happy to have a quick chat. Let me get back to you in a few days.