Wrong learning rate scheduler training step count in examples with multi-GPU when setting `--num_train_epochs`

Open geniuspatrick opened this issue 1 year ago • 4 comments

Describe the bug

I think there are still some problems with the learning rate scheduler. The issue is resolved when you set --max_train_steps, as discussed in #3954, but not completely.

Take, for example, the code snippet at https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L816-L833, pasted here:

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
)

# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    unet, optimizer, train_dataloader, lr_scheduler
)

When setting --num_train_epochs instead of --max_train_steps, the calculation of num_update_steps_per_epoch is incorrect because train_dataloader has not yet been wrapped by accelerator.prepare, so len(train_dataloader) is still the un-sharded length. Consequently, args.max_train_steps is roughly num_processes times the actual value. This discrepancy leads to unintended values being passed into the get_scheduler function.

In fact, the logic here is quite confusing. It seems like a refactoring might be necessary.
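
A rough sketch of one possible reorganization (not a final fix; it assumes the prepared dataloader is sharded evenly across accelerator.num_processes, and the *_for_scheduler names are only for illustration): estimate the per-process dataloader length before building the scheduler, then recompute args.max_train_steps after accelerator.prepare once the real sharded length is known.

# Sketch only: assumes even sharding across processes; variable names are illustrative.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
    # len(train_dataloader) is still the un-sharded length at this point,
    # so estimate the per-process length by dividing by the number of processes.
    len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
    num_update_steps_per_epoch = math.ceil(
        len_train_dataloader_after_sharding / args.gradient_accumulation_steps
    )
    num_training_steps_for_scheduler = (
        args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
    )
else:
    num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps_for_scheduler,
    num_training_steps=num_training_steps_for_scheduler,
)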

Reproduction

accelerate launch --mixed_precision="fp16" train_text_to_image.py \
  ...
-  --max_train_steps=15000 \
+  --num_train_epochs=100 \
  ...

Logs

No response

System Info

  • diffusers version: 0.27.2
  • Platform: macOS-10.16-x86_64-i386-64bit
  • Python version: 3.9.17
  • PyTorch version (GPU?): 2.0.1 (False)
  • Huggingface_hub version: 0.20.3
  • Transformers version: 4.30.0
  • Accelerate version: 0.21.0
  • xFormers version: not installed
  • Using GPU in script?: no
  • Using distributed or parallel set-up in script?: no

Who can help?

@sayakpaul @yiyixuxu @eliphatfs

geniuspatrick avatar May 23 '24 07:05 geniuspatrick

Consequently, args.max_train_steps is roughly num_processes times the actual value.

Could you explain why you think this is the case?

You seem to have an idea of what you'd like the code block to look like. So, if you want to take a stab at a PR, happy to review that too.

sayakpaul avatar May 23 '24 08:05 sayakpaul

Let's make some quick assumptions:

  • length of dataset: 8
  • batch size: 1
  • gradient accumulation steps: 1
  • number of GPUs (num_processes): 2
  • number of epochs (num_train_epochs): 1
  • max_train_steps (not set): None

Before accelerator.prepare is called on train_dataloader, the dataloader is created the standalone-training way, not the distributed-training way. So len(train_dataloader) = 8 ==> num_update_steps_per_epoch = 8 ==> args.max_train_steps = 8. However, we expect args.max_train_steps = 4 (each of the 2 processes only sees 4 batches per epoch), right?
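
A quick back-of-the-envelope check with those toy numbers (plain Python, independent of the script):

import math

dataset_len, batch_size, grad_accum, num_processes, num_epochs = 8, 1, 1, 2, 1

# Before accelerator.prepare: the dataloader still iterates over the whole dataset.
len_dataloader_unsharded = math.ceil(dataset_len / batch_size)                            # 8
max_train_steps_computed = num_epochs * math.ceil(len_dataloader_unsharded / grad_accum)  # 8

# After accelerator.prepare: each process only sees its shard of the data.
len_dataloader_sharded = math.ceil(dataset_len / (batch_size * num_processes))            # 4
max_train_steps_expected = num_epochs * math.ceil(len_dataloader_sharded / grad_accum)    # 4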

geniuspatrick avatar May 23 '24 08:05 geniuspatrick

Thank you. Would you be interested in a PR to fix this?

sayakpaul avatar May 23 '24 08:05 sayakpaul

I would like to do some more detailed verification first. Also, I hope @eliphatfs can help confirm that the issue described above is real.

geniuspatrick avatar May 23 '24 09:05 geniuspatrick

I tested that the scripts work for step-based training. For epoch-based training, I do think num_update_steps_per_epoch should be divided by the number of processes -- the current value appears incorrect.

eliphatfs avatar May 25 '24 19:05 eliphatfs

Cool. I think we're clear on the bug. I will open a PR to fix the issue ASAP.

geniuspatrick avatar May 26 '24 04:05 geniuspatrick

The length of the dataloader should reveal the number of batches, not the number of samples. Are we doing that here?
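
(For context, len() on a plain map-style DataLoader does count batches rather than samples; a toy check with made-up sizes:)

from torch.utils.data import DataLoader

# A list of 8 samples with batch_size=2 yields 4 batches.
loader = DataLoader(list(range(8)), batch_size=2)
print(len(loader))  # 4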

bghira avatar May 28 '24 23:05 bghira