
sub-par tuned model quality compared to LlamaFactory

Open weiran-work opened this issue 3 months ago • 19 comments

Please check that this issue hasn't been reported before.

  • [x] I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

I used the allenai/tulu-3-sft-mixture dataset to tune the pretrained Llama3.1-8B model for instruction following.

I tested both Axolotl and LlamaFactory. Using the same hyperparameters, the best saved checkpoint from Axolotl consistently underperforms the LlamaFactory one by a large margin on the IFEval metric.

LlamaFactory:

|Tasks |Version|Filter|n-shot|        Metric         |   |Value |   |Stderr|
|------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval|      4|none  |     0|inst_level_loose_acc   |↑  |0.7974|±  |   N/A|
|      |       |none  |     0|inst_level_strict_acc  |↑  |0.7698|±  |   N/A|
|      |       |none  |     0|prompt_level_loose_acc |↑  |0.7246|±  |0.0192|
|      |       |none  |     0|prompt_level_strict_acc|↑  |0.6987|±  |0.0197|

Axolotl:

|Tasks |Version|Filter|n-shot|        Metric         |   |Value |   |Stderr|
|------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval|      4|none  |     0|inst_level_loose_acc   |↑  |0.6403|±  |   N/A|
|      |       |none  |     0|inst_level_strict_acc  |↑  |0.6151|±  |   N/A|
|      |       |none  |     0|prompt_level_loose_acc |↑  |0.5471|±  |0.0214|
|      |       |none  |     0|prompt_level_strict_acc|↑  |0.5176|±  |0.0215|

I expect some differences between frameworks, but did not expect the gap to be this large.

Current behaviour

I'd expect the tuned model's quality to be similar to the LlamaFactory one, but as shown above it is significantly lower on IFEval.

Steps to reproduce

axolotl train <config-file>

Config yaml

base_model: meta-llama/Meta-Llama-3-8B

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

chat_template: jinja
chat_template_jinja: <path/to/tulu3.jinja>
datasets:
  - path: allenai/tulu-3-sft-mixture
    split: train
    type: chat_template
    field_messages: messages
dataset_prepared_path: last_run_prepared

dataloader_num_workers: 128
dataloader_pin_memory: true 
dataset_processes: 128

val_set_size: 0.05
output_dir: sft-llama3-no-packing-v1

sequence_len: 4096
sample_packing: false
flash_attention: true
gradient_checkpointing: false

wandb_project:
wandb_entity:
wandb_name: 

gradient_accumulation_steps: 8
micro_batch_size: 2
eval_batch_size: 4
num_epochs: 3
optimizer: adamw_torch_fused
learning_rate: 5e-6
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-5
max_grad_norm: 1.0

lr_scheduler: cosine
cosine_min_lr_ratio: 0.1
warmup_steps: 500
cosine_constant_lr_ratio: 0.0

bf16: true
tf32: true

logging_steps: 1
eval_strategy: "steps"
eval_steps: 1000
save_strategy: "best"
metric_for_best_model: "eval_loss"
greater_is_better: false
save_total_limit: 3

resume_from_checkpoint:

special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>

fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: SHARDED_STATE_DICT
  reshard_after_forward: true


tulu3.jinja

{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{% if not loop.last %}{{ '<|assistant|>\n'  + message['content'] + eos_token + '\n' }}{% else %}{{ '<|assistant|>\n'  + message['content'] + eos_token }}{% endif %}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}{% endfor %}

Possible solution

No response

Which Operating Systems are you using?

  • [ ] Linux
  • [ ] macOS
  • [ ] Windows

Python Version

3.11.13

axolotl branch-commit

main/050210e637a7ca2fdb65491eced13bd4d1ce5d10

Acknowledgements

  • [x] My issue title is concise, descriptive, and in title casing.
  • [x] I have searched the existing issues to make sure this bug has not been reported yet.
  • [x] I am using the latest version of axolotl.
  • [x] I have provided enough information for the maintainers to reproduce and diagnose the issue.

weiran-work avatar Sep 11 '25 22:09 weiran-work

That's a very interesting observation. We both use transformers under the hood (just different speed/VRAM optimizations), so I'm surprised the difference is this large.

Could you provide a repro for running on llama factory too?

NanoCode012 avatar Sep 12 '25 04:09 NanoCode012

Thanks @NanoCode012. Here's an example LlamaFactory config we used.

### model
model_name_or_path: meta-llama/Llama-3.1-8B
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: full
# deepspeed: examples/deepspeed/ds_z2_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
include_effective_tokens_per_second: true
flash_attn: fa2
enable_liger_kernel: true
packing: false
include_num_input_tokens_seen: true

### dataset
dataset: tulu_3_sft_mixture
template: "{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}{% if not loop.last %}{{ '<|assistant|>\n'  + message['content'] + eos_token + '\n' }}{% else %}{{ '<|assistant|>\n'  + message['content'] + eos_token }}{% endif %}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|assistant|>\n' }}{% endif %}{% endfor %}"
cutoff_len: 4096
# max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16
dataloader_num_workers: 16

### output
output_dir: saves/llama3p1-8b/full/fa2_liger_fsdp_chat_template
logging_steps: 10
save_steps: 5000
plot_loss: true
overwrite_output_dir: true
save_only_model: false
report_to: wandb  # choices: [none, wandb, tensorboard, swanlab, mlflow]
run_name: llama3p1-8b-sft-fa2_liger_fsdp

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 16
learning_rate: 5.0e-6
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-5
num_train_epochs: 4.0
lr_scheduler_type: cosine
warmup_ratio: 0.03
bf16: true
ddp_timeout: 180000000
resume_from_checkpoint: null

### eval
# eval_dataset: alpaca_en_demo
val_size: 0.05
per_device_eval_batch_size: 4
eval_strategy: steps
eval_steps: 500
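
The config is launched with the standard LlamaFactory CLI; the filename below is just an example:

llamafactory-cli train llama3p1_full_sft.yaml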

weiran-work avatar Sep 12 '25 21:09 weiran-work

meta-llama/Meta-Llama-3-8B and meta-llama/Llama-3.1-8B are different models. Unless that was a typo?

xzuyn avatar Sep 13 '25 11:09 xzuyn

@xzuyn That's a typo from when I was cleaning up the Axolotl config, sorry about that. I was using a local path to the Llama3.1-8B model with Axolotl, so both experiments used the same Llama3.1-8B base model.

weiran-work avatar Sep 13 '25 14:09 weiran-work

Thanks for the repro. We will look into this

NanoCode012 avatar Sep 16 '25 07:09 NanoCode012

Hey, I tried to repro this on LlamaFactory. Do I need to register the chat template manually? Also, regarding the dataset, I need to add it to the dataset_info.json file; could you share that entry and anything else needed?

ved1beta avatar Sep 17 '25 05:09 ved1beta

@ved1beta Thanks for trying

Do I need to register the chat template manually?

You can add the template to the tokenizer_config.json
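
For example, a minimal sketch of the tokenizer_config.json entry: chat_template is the standard HF tokenizer config field, and the full tulu3.jinja string from above goes in as the value (it is truncated here for readability):

{
  "chat_template": "{% for message in messages %}{% if message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + '\n' }}{% elif message['role'] == 'assistant' %}... (rest of the tulu3 template) ...{% endif %}{% endfor %}"
}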

I need to add it to the dataset_info.json file; could you share that entry?

This is the entry that was added:

"tulu_3_sft_mixture": {
    "hf_hub_url": "allenai/tulu-3-sft-mixture",
    "formatting": "sharegpt",
    "columns": {
      "messages": "messages"
    },
    "tags": {
      "role_tag": "role",
      "content_tag": "content",
      "user_tag": "user",
      "assistant_tag": "assistant",
      "system_tag": "system"
    }
  },
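
(For anyone else reproducing: this entry goes into LlamaFactory's data/dataset_info.json.)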

weiran-work avatar Sep 17 '25 19:09 weiran-work

@weiran-work How many GPUs for FSDP? H100s? This affects reproducibility on our end since it affects the effective batch size. Thanks!

winglian avatar Sep 19 '25 15:09 winglian

@winglian We used 8xH100; the effective batch size is 128 (micro_batch_size 2 × gradient_accumulation_steps 8 × 8 GPUs).

weiran-work avatar Sep 19 '25 18:09 weiran-work

@weiran-work I think the problem lies in the different starting base models. You used the older Llama 3 in Axolotl and the newer 3.1 in LlamaFactory: base_model: meta-llama/Meta-Llama-3-8B vs. model_name_or_path: meta-llama/Llama-3.1-8B.

winglian avatar Sep 21 '25 11:09 winglian

@xzuyn That's a typo from when I was cleaning up the Axolotl config, sorry about that. I was using a local path to the Llama3.1-8B model with Axolotl, so both experiments used the same Llama3.1-8B base model.

@winglian OP addressed this above.

SalmanMohammadi avatar Sep 21 '25 12:09 SalmanMohammadi

Based on our latest experiments, one contributing factor is FSDP1 vs. FSDP2. We suspect this comes from upstream HF's implementation. We'll just stay away from FSDP2 for now.

FSDP1 result

|    Tasks     |Version|Filter|n-shot|        Metric         |   |Value |   |Stderr|
|--------------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval        |      4|none  |     0|inst_level_loose_acc   |↑  |0.7530|±  |   N/A|
|              |       |none  |     0|inst_level_strict_acc  |↑  |0.7278|±  |   N/A|
|              |       |none  |     0|prompt_level_loose_acc |↑  |0.6654|±  |0.0203|
|              |       |none  |     0|prompt_level_strict_acc|↑  |0.6396|±  |0.0207|

FSDP2 result

|    Tasks     |Version|Filter|n-shot|        Metric         |   |Value |   |Stderr|
|--------------|------:|------|-----:|-----------------------|---|-----:|---|------|
|ifeval        |      4|none  |     0|inst_level_loose_acc   |↑  |0.4065|±  |   N/A|
|              |       |none  |     0|inst_level_strict_acc  |↑  |0.3741|±  |   N/A|
|              |       |none  |     0|prompt_level_loose_acc |↑  |0.2791|±  |0.0193|
|              |       |none  |     0|prompt_level_strict_acc|↑  |0.2458|±  |0.0185|

Loss graphs

[image: FSDP1 vs. FSDP2 loss curves]

weiran-work avatar Sep 22 '25 17:09 weiran-work

@weiran-work Thanks for sharing. Did you run these fsdp experiments on Axolotl? If yes, could we have the configs to repro? Perhaps there's some knob that's causing this discrepancy.

Otherwise, let us know how to repro

NanoCode012 avatar Sep 23 '25 13:09 NanoCode012

Yes, it is with Axolotl

This is the FSDP1 config

base_model: meta-llama/Llama-3.1-8B

plugins:
  - axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true

chat_template: llama3
datasets:
  - path: allenai/tulu-3-sft-mixture
    split: train
    type: chat_template
    field_messages: messages
dataset_prepared_path: last_run_prepared

dataloader_num_workers: 128
dataloader_pin_memory: true 
dataset_processes: 128

val_set_size: 0.05
output_dir: sft-llama3-packed-fsdp1

sequence_len: 4096
sample_packing: true
flash_attention: true
gradient_checkpointing: false

wandb_project: 
wandb_entity: 
wandb_name: 

# max_steps: 10
gradient_accumulation_steps: 8
micro_batch_size: 2
eval_batch_size: 1
num_epochs: 3
optimizer: adamw_torch_fused
learning_rate: 5e-6
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
adam_epsilon: 1e-5
max_grad_norm: 1.0

lr_scheduler: cosine
cosine_min_lr_ratio: 0.1
warmup_steps: 500
cosine_constant_lr_ratio: 0.0

bf16: true
tf32: true

logging_steps: 1
eval_strategy: "steps"
eval_steps: 200
save_strategy: "best"
metric_for_best_model: "eval_loss"
greater_is_better: false
save_total_limit: 8
save_only_model: true

resume_from_checkpoint:

special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot_id|>

fsdp:
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: false
  fsdp_sync_module_states: true
  fsdp_offload_params: false
  fsdp_use_orig_params: true
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: SHARD_GRAD_OP
  fsdp_backward_prefetch: BACKWARD_PRE

For the FSDP2 config, the only diff is that the FSDP block above is replaced with the following:

fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: LlamaDecoderLayer
  state_dict_type: FULL_STATE_DICT
  reshard_after_forward: true

weiran-work avatar Sep 23 '25 18:09 weiran-work

btw, did you use FSDP2 with llama-factory as well?

winglian avatar Sep 24 '25 13:09 winglian

@winglian LlamaFactory didn't officially support FSDP2 when we tested it a few weeks ago.

weiran-work avatar Sep 24 '25 18:09 weiran-work

Hi @weiran-work. Thanks so much for reporting this issue and helping us reproduce it.

I don't think this issue is fundamentally related to FSDP1 vs. FSDP2. I believe it is due to how Accelerate handles model precision for FSDP - the default behaviour in Accelerate is to forcibly upcast the model to FP32 for both FSDP1 and FSDP2. We have disabled this behaviour for FSDP2 but not for FSDP1, as we are aiming to deprecate FSDP1 soon.

Could you try running your FSDP2 config using the fsdp2_fp32 branch? I've reverted the upcast-disabling behaviour for FSDP2 so you should see pretty much identical memory usage and loss curves between FSDP1 and FSDP2 - I've verified this on our end and you can see my loss curves below.
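
For reference, installing from that branch would look roughly like this:

# assumes the fsdp2_fp32 branch lives on the main axolotl repo; adjust the remote if it's on a fork
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
git checkout fsdp2_fp32
pip install -e .
axolotl train <config-file>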

[image: loss curves comparing FSDP1 and FSDP2 on the fsdp2_fp32 branch]

As for a solution, we can make it easier to configure FP32 fine-tuning, as it's currently not possible to do so in axolotl without using my branch above. Is this something you would find helpful? We historically have not focused on this feature, as many users find the tradeoff of slightly reduced model quality for significantly lower memory usage acceptable.

A general note on reproducibility: it's worth fixing your seed and dataset_prepared_path in your config when comparing runs, as this ensures different runs use the same pre-processed dataset and random seed.
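
Something like the following in the config would do it (values here are illustrative):

seed: 42                                  # any fixed value; just keep it identical across runs
dataset_prepared_path: last_run_prepared  # reuse the same pre-processed dataset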

SalmanMohammadi avatar Sep 26 '25 16:09 SalmanMohammadi

@weiran-work separately I'm impressed at your thoroughness in investigating this issue - we'd be interested in hearing more about how you're using Axolotl and to see if we could provide more support. Would you be up for a quick chat?

SalmanMohammadi avatar Sep 26 '25 17:09 SalmanMohammadi

@SalmanMohammadi Thank you for root-causing this! We'll give that branch a try at the end of this week or early next week.

And I personally think exposing that upcasting as an option would be great. Precision does matter, especially for CPT or instruction tuning.

Happy to have a quick chat. Let me get back to you in a few days.

weiran-work avatar Sep 30 '25 17:09 weiran-work