
Inconsistent training steps between Trainer and DeepSpeed

Open fenchri opened this issue 1 year ago • 5 comments

System Info

  • transformers version: 4.26.0
  • Platform: Linux-5.4.0-136-generic-x86_64-with-glibc2.17
  • Python version: 3.8.16
  • Huggingface_hub version: 0.12.1
  • PyTorch version (GPU?): 1.12.0+cu113 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: Yes

DeepSpeed general environment info:
  • torch install path: ['/home/fenia/anaconda3/envs/benchmark/lib/python3.8/site-packages/torch']
  • torch version: 1.12.0+cu113
  • deepspeed install path: ['/home/fenia/anaconda3/envs/benchmark/lib/python3.8/site-packages/deepspeed']
  • deepspeed info: 0.8.1, unknown, unknown
  • torch cuda version: 11.3
  • torch hip version: None
  • nvcc version: 11.8
  • deepspeed wheel compiled w.: torch 1.12, cuda 11.3

Who can help?

@stas00

Information

  • [X] The official example scripts
  • [ ] My own modified scripts

Tasks

  • [X] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • [ ] My own task or dataset (give details below)

Reproduction

Hello!

There seems to be some inconsistency in the number of training steps when using DeepSpeed with the HF Trainer. DeepSpeed appears to do things correctly, but it ends up training for more steps in order to match the Trainer's step count. Both keep training even after the learning rate has dropped to 0.

From the official examples:

ds_config_zero2.json:
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "bf16": {
        "enabled": "auto"
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 10,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
DISTRIBUTED_ARGS="--nproc_per_node 2 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 6000"
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
	run_clm.py \
	--model_name_or_path gpt2 \
	--dataset_name wikitext \
	--dataset_config_name wikitext-103-raw-v1 \
	--per_device_train_batch_size 2 \
	--per_device_eval_batch_size 1 \
	--do_train \
	--output_dir /tmp/test-clm2 \
	--max_train_samples=148 \
	--gradient_accumulation_steps=16 \
	--overwrite_output_dir \
	--max_steps=200 \
	--logging_steps=10 \
	--deepspeed="ds_config_zero2.json"
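For reference, here is a quick sanity check of the numbers this command implies. It is a sketch under a few assumptions: that the "auto" batch-size entries are filled from the command-line arguments, that DeepSpeed's rule train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size applies, and that the per-rank batch count comes from a DistributedSampler and DataLoader that do not drop the last incomplete batch.

```python
import math

# Values taken from the command line above; world_size is the --nproc_per_node value.
per_device_train_batch_size = 2
gradient_accumulation_steps = 16
world_size = 2
max_steps = 200
max_train_samples = 148

# DeepSpeed's batch-size identity (assumed to be how the "auto" entries get resolved).
train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size

# Micro-batches each rank should consume to reach max_steps optimizer steps.
micro_batches_per_rank = max_steps * gradient_accumulation_steps

# Each rank sees ceil(148 / 2) samples per epoch, grouped into batches of 2.
samples_per_rank = math.ceil(max_train_samples / world_size)
batches_per_rank_per_epoch = math.ceil(samples_per_rank / per_device_train_batch_size)

print(train_batch_size)            # 64
print(micro_batches_per_rank)      # 3200
print(batches_per_rank_per_epoch)  # 37
```

Since 37 is not a multiple of 16, every epoch ends partway through an accumulation cycle, and that partial cycle is what exposes the bug discussed below.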

I've attached the training output: output.txt

The same behavior is observed even when training with Trainer+DeepSpeed on a single GPU.

Expected behavior

The number of training steps should match between the Trainer and DeepSpeed logs.

Thank you very much in advance!

fenchri · Mar 10 '23 11:03

Thank you for the great and easy-to-reproduce report, @fenchri.

Indeed, you found a gradient accumulation bug in the HF Trainer. This is not a bug in DeepSpeed or its integration.

I added a debug print:

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 344523842..a75110ee9 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1886,6 +1886,7 @@ class Trainer:
                 if step % args.gradient_accumulation_steps == 0:
                     self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

+                print(f"HF STEP {step+1}")
                 if (
                     ((step + 1) % args.gradient_accumulation_steps != 0)
                     and args.local_rank != -1

and then ran it without DeepSpeed:

python -m torch.distributed.launch --nproc_per_node 1 --nnodes 1 --node_rank 0 \
--master_addr localhost --master_port 6000 \
examples/pytorch/language-modeling/run_clm.py \
--model_name_or_path sshleifer/tiny-gpt2 --dataset_name wikitext \
--dataset_config_name wikitext-103-raw-v1 --per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 --do_train --block_size 10 --output_dir \
output_dir --max_train_samples=148 --gradient_accumulation_steps=16 \
--overwrite_output_dir --max_steps=10 --logging_steps=1

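Since you set --max_train_samples=148 --gradient_accumulation_steps=16, the dataset wraps around at step 8->9. However, the grad accumulation check uses the per-epoch step, which restarts at 0, so the leftover micro-batches from the previous epoch are not flushed and accumulation continues while ((step + 1) % args.gradient_accumulation_steps != 0).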
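The effect can be reproduced with a small pure-Python simulation. It is a sketch, not the Trainer code: it assumes the optimizer-step condition is (step + 1) % gradient_accumulation_steps == 0, or the last step of an epoch when the epoch is no longer than one accumulation cycle, with `step` reset to 0 every epoch. The function name is hypothetical.

```python
from typing import List


def microbatches_per_update(steps_in_epoch: int, gas: int, max_steps: int) -> List[int]:
    """Count how many micro-batches feed each optimizer step when the
    accumulation boundary is computed from the per-epoch `step` index."""
    counts: List[int] = []
    pending = 0  # micro-batches accumulated but not yet applied by an optimizer step
    while len(counts) < max_steps:
        for step in range(steps_in_epoch):  # `step` restarts at 0 every epoch
            pending += 1
            is_boundary = (step + 1) % gas == 0 or (
                steps_in_epoch <= gas and (step + 1) == steps_in_epoch
            )
            if is_boundary:
                counts.append(pending)
                pending = 0
                if len(counts) == max_steps:
                    break
        # any leftover `pending` micro-batches silently carry into the next epoch


    return counts


# Single-process rerun above: 148 samples / batch size 2 = 74 micro-batches per epoch.
counts = microbatches_per_update(steps_in_epoch=74, gas=16, max_steps=10)
print(counts)  # [16, 16, 16, 16, 26, 16, 16, 16, 26, 16]
```

Under these assumptions, optimizer steps 5 and 9 each absorb 10 leftover micro-batches (HF STEP 65-74) on top of the next 16.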

so when we run it, we get:

[skipped the first 7 grad acc cycles]
{'loss': 10.823, 'learning_rate': 1.5e-05, 'epoch': 1.65}                                                                                 
 70%|███████████████████████████████████████████████████████████████████████▍                              | 7/10 [00:00<00:00, 10.01it/s]
HF STEP 49
HF STEP 50
HF STEP 51
HF STEP 52
HF STEP 53
HF STEP 54
HF STEP 55
HF STEP 56
HF STEP 57
HF STEP 58
HF STEP 59
HF STEP 60
HF STEP 61
HF STEP 62
HF STEP 63
HF STEP 64
{'loss': 10.8266, 'learning_rate': 1e-05, 'epoch': 1.86}                                                                                  
 80%|█████████████████████████████████████████████████████████████████████████████████▌                    | 8/10 [00:01<00:00, 10.01it/s]
HF STEP 65
HF STEP 66
HF STEP 67
HF STEP 68
HF STEP 69
HF STEP 70
HF STEP 71
HF STEP 72
HF STEP 73
HF STEP 74
HF STEP 1
HF STEP 2
HF STEP 3
HF STEP 4
HF STEP 5
HF STEP 6
HF STEP 7
HF STEP 8
HF STEP 9
HF STEP 10
HF STEP 11
HF STEP 12
HF STEP 13
HF STEP 14
HF STEP 15
HF STEP 16
{'loss': 17.593, 'learning_rate': 5e-06, 'epoch': 2.22}                                                                                   
 90%|███████████████████████████████████████████████████████████████████████████████████████████▊          | 9/10 [00:01<00:00, 11.05it/s]
HF STEP 17
HF STEP 18
HF STEP 19
HF STEP 20
HF STEP 21
HF STEP 22
HF STEP 23
HF STEP 24
HF STEP 25
HF STEP 26
HF STEP 27
HF STEP 28
HF STEP 29
HF STEP 30
HF STEP 31
HF STEP 32
{'loss': 10.8249, 'learning_rate': 0.0, 'epoch': 2.43}  

You can see that between steps 8 and 9, 26 micro-batches are accumulated (HF STEP 65-74, then 1-16) instead of 16.
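The learning_rate values in this log are consistent with a plain linear decay. The following sketch assumes the Trainer defaults were in effect, since the command sets none of them: learning rate 5e-5, a linear schedule, and no warmup. It re-implements the linear-with-warmup formula rather than calling transformers, and the function name is hypothetical. It shows that the rate is clamped at 0 once max_steps is reached, so any training that continues past that point runs at a learning rate of 0.

```python
import math


def linear_lr(step: int, base_lr: float = 5e-5, total_steps: int = 10, warmup_steps: int = 0) -> float:
    """Learning rate after `step` scheduler steps under linear warmup + linear decay."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    # Linear decay from base_lr to 0 at total_steps, clamped at 0 afterwards.
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


for s in (7, 8, 9, 10, 12):
    print(s, linear_lr(s))
```

Steps 7 to 10 reproduce the logged 1.5e-05, 1e-05, 5e-06 and 0.0 (up to float rounding). Step 12 stays at 0.0.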


Until this is fixed, and specifically for your needs, @fenchri: as long as you're using DeepSpeed, gradient accumulation is performed correctly, since DeepSpeed handles it on its own. But you end up running more steps than specified.

stas00 · Mar 10 '23 23:03

Hmm, actually looking at earlier steps, this appears to be odd as well:

{'loss': 10.8252, 'learning_rate': 3e-05, 'epoch': 0.86}                                                                                  
 40%|████████████████████████████████████████▊                                                             | 4/10 [00:00<00:01,  5.14it/s]
HF STEP 65
HF STEP 66
HF STEP 67
HF STEP 68
HF STEP 69
HF STEP 70
HF STEP 71
HF STEP 72
HF STEP 73
HF STEP 74
HF STEP 75
HF STEP 76
HF STEP 77
HF STEP 78
HF STEP 79
HF STEP 80
HF STEP 81
HF STEP 82
HF STEP 83
HF STEP 84
HF STEP 85
HF STEP 86
HF STEP 87
HF STEP 88
HF STEP 89
HF STEP 90

It did 9 additional dataset pulls here as well (25 instead of 16), and this is not at the grad accum boundary.

Edit: ah, it's because bs=2, so it already hits the rollover at step 4->5; that's why.

stas00 · Mar 11 '23 00:03

OK, actually I came up with a fix; will push it shortly for you to try.

Please try https://github.com/huggingface/transformers/pull/22098
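One way to make the boundary robust to epoch wrap-around is to decide it from a counter that never resets between epochs. The sketch below is a pure-Python simulation in the spirit of the linked PR. We have not checked the PR's diff line by line, so the counter name `total_batched_samples` and the function name are illustrative assumptions, not the Trainer's code.

```python
from typing import List


def microbatches_per_update_fixed(steps_in_epoch: int, gas: int, max_steps: int) -> List[int]:
    """Count micro-batches per optimizer step when the accumulation boundary
    is computed from a global counter that keeps running across epochs."""
    counts: List[int] = []
    pending = 0  # micro-batches accumulated but not yet applied
    total_batched_samples = 0  # hypothetical global counter, never reset per epoch
    while len(counts) < max_steps:
        for step in range(steps_in_epoch):
            pending += 1
            total_batched_samples += 1
            is_boundary = total_batched_samples % gas == 0 or (
                steps_in_epoch <= gas and (step + 1) == steps_in_epoch
            )
            if is_boundary:
                counts.append(pending)
                pending = 0
                if len(counts) == max_steps:
                    break
    return counts


fixed = microbatches_per_update_fixed(steps_in_epoch=74, gas=16, max_steps=10)
print(fixed)  # [16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
```

With the global counter, the 10 micro-batches left at the end of an epoch simply become the first 10 of the next cycle, so every optimizer step sees exactly 16.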

stas00 · Mar 11 '23 00:03

Thanks @stas00 for having a look and apologies for the late reply. Indeed, the fix resolves the issue! :tada:

On a related note, the computation happening here seems to truncate num_update_steps_per_epoch (integer division) even if drop_last is False. This results in 100 training epochs instead of 87, which then gets printed here.
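The 100 vs. 87 figures can be reproduced with simple arithmetic. This sketch assumes the original 2-GPU run (37 batches per rank per epoch, from 148 samples / 2 ranks / batch size 2). It also assumes the epoch count is max_steps divided by update steps per epoch, rounded up. It illustrates the arithmetic only and is not the Trainer's code.

```python
import math

len_dataloader = 37  # batches per rank per epoch: 148 samples / 2 ranks / batch size 2
gas = 16             # --gradient_accumulation_steps
max_steps = 200      # --max_steps

# Truncated update steps per epoch: 37 // 16 = 2
num_update_steps_per_epoch = max(len_dataloader // gas, 1)
epochs_truncated = max_steps // num_update_steps_per_epoch + int(
    max_steps % num_update_steps_per_epoch > 0
)  # 100

# Keeping the fractional 37 / 16 = 2.3125 update steps per epoch instead
epochs_fractional = math.ceil(max_steps * gas / len_dataloader)  # ceil(86.49) = 87

print(num_update_steps_per_epoch, epochs_truncated, epochs_fractional)
```

The truncation drops the partial accumulation cycle at the end of each epoch (5 of the 37 batches), which is why the reported epoch count overshoots.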

Nevertheless, with the current fix the training stops at the desired number of steps, so should be fine :)

I am happy to open another issue related to this though if you think it's needed :)

Thank you!

fenchri · Mar 11 '23 15:03

> Thanks @stas00 for having a look and apologies for the late reply. Indeed, the fix resolves the issue! :tada:

Excellent! Thank you for testing the PR, @fenchri.

> I am happy to open another issue related to this though if you think it's needed :)

Yes, please. One issue at a time.

stas00 · Mar 11 '23 17:03