
Erroneous/high loss with DeepSpeed Zero3 and bf16

Open · mnmueller opened this issue 4 months ago · 0 comments

Reminder

  • [X] I have read the README and searched the existing issues.

Reproduction

We observe that, across models and datasets, Zero3 with bf16 yields much higher losses than Zero2 with bf16 or Zero3 with fp16 or fp32 (the latter configurations obtaining very similar losses). Please follow the steps below to reproduce.

The top trace is Z3 with bf16; the bottom four are Z2 bf16 (the highest of the four), Z3 fp16, Z3 fp32, and Z2 fp32, where the last three yield virtually identical results. [Screenshot: training-loss curves for the five runs.]

This error becomes significantly more pronounced for large context lengths.
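
For intuition only (an assumption about a plausible mechanism, not a diagnosis of the bug): bf16 keeps about 8 mantissa bits versus fp16's 11, so every rounded value carries roughly 8x more relative error, and that rounding error has more opportunity to accumulate over the longer reductions that come with large context lengths. A minimal PyTorch sketch of the per-element rounding error:

import torch

x = torch.randn(2048, 2048)
for dtype in (torch.bfloat16, torch.float16):
    # Round-trip through the low-precision dtype and measure the relative error.
    rel = ((x.to(dtype).float() - x).abs() / x.abs().clamp_min(1e-12)).mean()
    print(f"{dtype}: mean relative rounding error ~ {rel:.1e}")

Whether this precision gap alone explains the Zero3-specific behaviour is exactly the question of this issue.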

First set up the environment:

git clone git@github.com:hiyouga/LLaMA-Factory.git
cd LLaMA-Factory

conda create --name debugZ3 python==3.10
conda activate debugZ3
pip install -r requirements.txt
pip install deepspeed
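
Optionally (not part of the original setup), a quick Python sanity check that the environment actually supports bf16 and that DeepSpeed imports cleanly:

import deepspeed
import torch

print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0))
print("bf16 supported:", torch.cuda.is_bf16_supported())
print("torch:", torch.__version__, "| deepspeed:", deepspeed.__version__)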

Now create the following DeepSpeed config files.

zero3_bf16.json:

{
  "zero_optimization": {
    "stage": 3
  },
  "zero3_init_flag": false,
  "bf16": {
    "enabled": true
  },
  "fp16": {
    "enabled": false
  },
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false,
  "dist_init_required":true
}

zero3_fp16.json:

{
  "zero_optimization": {
    "stage": 3
  },
  "bf16": {
    "enabled": false
  },
  "fp16": {
        "enabled": true,
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "consecutive_hysteresis": false,
        "min_loss_scale": 1
    },
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}

zero2_bf16.json:

{
  "zero_optimization": {
    "stage": 2
  },
  "bf16": {
    "enabled": true
  },
  "fp16": {
    "enabled": false
  },
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
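
As an optional sanity check (not part of the original report), the three files can be loaded to confirm that each run enables exactly the intended precision; the filenames are the ones created above:

import json

for path in ("zero3_bf16.json", "zero3_fp16.json", "zero2_bf16.json"):
    with open(path) as f:
        cfg = json.load(f)
    print(path,
          "| stage:", cfg["zero_optimization"]["stage"],
          "| bf16:", cfg.get("bf16", {}).get("enabled"),
          "| fp16:", cfg.get("fp16", {}).get("enabled"))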

Run LLaMA-Factory with each of these settings. We use Slurm with torchrun as the launcher, phi-2 as the model, and wiki_demo as the dataset.

With Zero3 bf16:

srun --jobid "$SLURM_JOBID" \
    bash -c 'torchrun --nproc_per_node "$GPUS_PER_NODE" --nnodes "$SLURM_NNODES" --node_rank "$SLURM_PROCID" \
    --master_addr "$MASTER_ADDR" --master_port "$MASTER_PORT" \
    ./LLaMA-Factory/src/train_bash.py \
    --deepspeed zero3_bf16.json \
    --stage pt \
    --do_train \
    --model_name_or_path microsoft/phi-2 \
    --dataset wiki_demo \
    --finetuning_type full \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 1 \
    --learning_rate 1e-5 \
    --warmup_steps 100 \
    --num_train_epochs 1.0 \
    --bf16 \
    --cutoff_len 2048'

With Zero3 fp16:

srun --jobid "$SLURM_JOBID" \
    bash -c 'torchrun --nproc_per_node "$GPUS_PER_NODE" --nnodes "$SLURM_NNODES" --node_rank "$SLURM_PROCID" \
    --master_addr "$MASTER_ADDR" --master_port "$MASTER_PORT" \
    ./LLaMA-Factory/src/train_bash.py \
    --deepspeed zero3_fp16.json \
    --stage pt \
    --do_train \
    --model_name_or_path microsoft/phi-2 \
    --dataset wiki_demo \
    --finetuning_type full \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 1 \
    --learning_rate 1e-5 \
    --warmup_steps 100 \
    --num_train_epochs 1.0 \
    --fp16 \
    --cutoff_len 2048'

With Zero2 bf16:

srun --jobid "$SLURM_JOBID" \
    bash -c 'torchrun --nproc_per_node "$GPUS_PER_NODE" --nnodes "$SLURM_NNODES" --node_rank "$SLURM_PROCID" \
    --master_addr "$MASTER_ADDR" --master_port "$MASTER_PORT" \
    ./LLaMA-Factory/src/train_bash.py \
    --deepspeed zero2_bf16.json \
    --stage pt \
    --do_train \
    --model_name_or_path microsoft/phi-2 \
    --dataset wiki_demo \
    --finetuning_type full \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --lr_scheduler_type cosine \
    --logging_steps 1 \
    --learning_rate 1e-5 \
    --warmup_steps 100 \
    --num_train_epochs 1.0 \
    --bf16 \
    --cutoff_len 2048'
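
To compare the runs, the loss curves can be read back from the trainer_state.json that the Hugging Face Trainer writes to each run's output directory. The directory names below are hypothetical placeholders; substitute whatever --output_dir each run used (not shown in the commands above):

import json

# Hypothetical output directories, one per run; adjust to your --output_dir values.
runs = {
    "Z3 bf16": "out_z3_bf16",
    "Z3 fp16": "out_z3_fp16",
    "Z2 bf16": "out_z2_bf16",
}
for name, path in runs.items():
    with open(f"{path}/trainer_state.json") as f:
        state = json.load(f)
    losses = [entry["loss"] for entry in state["log_history"] if "loss" in entry]
    print(f"{name}: first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")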

Expected behavior

We would expect Zero2 and Zero3 to produce virtually identical losses even when using bf16, just as they do with fp32.

System Info

  • transformers version: 4.37.2
  • Platform: Linux-5.15.0-1048-oracle-x86_64-with-glibc2.35
  • Python version: 3.10.0
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: 0.27.2
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: 8 x 40GB A100
  • Using distributed or parallel set-up in script?: using torchrun via slurm as described

Others

No response

mnmueller · Feb 15 '24