
FSDP with `cpu_ram_efficient_loading=false` results in NaN loss and gradient values

Open · Co1lin opened this issue 4 months ago · 1 comment

Please check that this issue hasn't been reported before.

  • [x] I searched previous Bug Reports and didn't find any similar reports.

Expected Behavior

With the configuration below, changing only `cpu_ram_efficient_loading` to `false` results in NaN loss and gradient values. Training with `cpu_ram_efficient_loading: true` and `bf16: true` works, but `fp8: true` requires `cpu_ram_efficient_loading: false`, so this bug prevents FP8 training.
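For context, here is a sketch of the FP8 variant. It assumes the variant is built from the config below by uncommenting its two `fp8` keys. According to the report, this combination forces `cpu_ram_efficient_loading: false` and therefore hits the NaN problem:

```yaml
# FP8 variant of the config below (sketch): only the keys that change are shown
bf16: true
fp8: true                                 # commented out in the posted config
fp8_enable_fsdp_float8_all_gather: true   # commented out in the posted config
fsdp_config:
  # Per the report, fp8 requires false here, which is the setting that yields NaN loss/grads.
  # All other fsdp_config keys stay as in the posted config.
  cpu_ram_efficient_loading: false
```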

```yaml
seed: 42
# Allow overwriting the yml config from the CLI
strict: false
# Resume from a specific checkpoint dir
# resume_from_checkpoint:
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# This can also be a relative path to a model on disk
base_model: Qwen/Qwen3-Coder-30B-A3B-Instruct
# base_model: Qwen/Qwen3-8B
trust_remote_code: true
# Where to save the full-finetuned model to
output_dir: train/outputs/sft_stack_test_8k

# wandb configuration if you're using it
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or log in to wandb with `wandb login`.
# wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: ar_sft # Your wandb project name
wandb_entity: cu_colin_org # A wandb Team name if using a Team
# wandb_watch:
wandb_name: sft_stack_30k # Set the name of your wandb run
wandb_run_id: # Set the ID of your wandb run
# wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

# Training hyperparameters
sequence_len: 20480
# dp_replicate_size: 1
# Total: dp_shard_size × tensor_parallel_size × context_parallel_size = #GPUs on a single node
# dp_shard_size: 8         # FSDP across ? GPUs
# tensor_parallel_size: 1  # TP across ? GPUs
# context_parallel_size: 1 # CP across ? GPUs
flash_attention: true
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

# global batch size = micro_batch_size * dp_shard_size * gradient_accumulation_steps
micro_batch_size: 4
eval_batch_size: 4
gradient_accumulation_steps: 4

num_epochs: 2
# warmup_steps: 100  # cannot use with warmup_ratio
warmup_ratio: 0.05  # cannot use with warmup_steps
learning_rate: 2.0e-5 # 0.00003
logging_steps: 1
eval_steps: 0.1 # Leave empty to eval at each epoch, an integer for every N steps, a decimal for a fraction of total steps
save_steps: # Leave empty to save at each epoch
# saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
save_total_limit: 1000
gc_steps: 1

val_set_size: 0.05
# A list of one or more datasets to finetune the model with
datasets:
  # Alpaca-format instruction data (not a chat template)
  - path: tatsu-lab/alpaca
    type: alpaca

# https://datascience.stackexchange.com/questions/24511/why-should-the-data-be-shuffled-for-machine-learning-tasks/24524#24524
# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
shuffle_merged_datasets: true
# Deduplicates datasets and test_datasets with identical entries.
dataset_exact_deduplication: false
# The name of the chat template to use for training, following values are supported:
# chat_template: tokenizer_default
# Changes the default system message
# default_system_message: A chat between a curious User and an artificial intelligence Bot. # Currently only supports chatml.
# Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path
dataset_prepared_path: train/axolotl_datasets/last_run_prepared
dataset_processes: 64

# Save model as safetensors (requires the safetensors package)
# save_safetensors: true

# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
# gradient_checkpointing: true
# additional kwargs to pass to the trainer for gradient checkpointing
# gradient_checkpointing_kwargs:
#   use_reentrant: false

optimizer: adamw_torch_fused
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: cosine # 'one_cycle' | 'log_sweep' | empty for cosine
# lr_scheduler_kwargs:
# cosine_min_lr_ratio: 0.1 # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
# cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)

plugins:
  - axolotl.integrations.liger.LigerPlugin
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true


bf16: true
tf32: true
# fp8: true
# fp8_enable_fsdp_float8_all_gather: true
torch_compile: true

# FSDP configuration
fsdp_version: 2
fsdp_config:
  offload_params: false
  cpu_ram_efficient_loading: true
  auto_wrap_policy: TRANSFORMER_BASED_WRAP
  transformer_layer_cls_to_wrap: Qwen3MoeDecoderLayer # Qwen3MoeDecoderLayer
  state_dict_type: FULL_STATE_DICT
  reshard_after_forward: true
  activation_checkpointing: true
```

Current behaviour

The loss and gradient values are NaN when `cpu_ram_efficient_loading` is set to `false`.

Steps to reproduce

Launch training with the config above after changing `cpu_ram_efficient_loading` to `false`.
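As a sketch, the failing run needs a single override relative to the posted config. It assumes that every other key, including the rest of `fsdp_config`, stays exactly as posted:

```yaml
fsdp_version: 2
fsdp_config:
  # true: bf16 training runs normally; false: loss and gradients are NaN (per the report)
  cpu_ram_efficient_loading: false
  # ...remaining fsdp_config keys unchanged from the posted config
```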

Config yaml

See the configuration under Expected Behavior above.

Possible solution

No response

Which Operating Systems are you using?

  • [x] Linux
  • [ ] macOS
  • [ ] Windows

Python Version

3.12.8

axolotl branch-commit

axolotl[deepspeed, ring-flash-attn] v0.12.1

Acknowledgements

  • [x] My issue title is concise, descriptive, and in title casing.
  • [x] I have searched the existing issues to make sure this bug has not been reported yet.
  • [x] I am using the latest version of axolotl.
  • [x] I have provided enough information for the maintainers to reproduce and diagnose the issue.

Co1lin · Aug 17 '25 23:08

Hey, FSDP2 with `cpu_ram_efficient_loading` should work in Axolotl. Could you let me know if you've given it a try?

NanoCode012 · Sep 03 '25 09:09