
LR Scheduler Bugs in Finetuning Code

Open aHapBean opened this issue 9 months ago • 2 comments

Hi. I’ve observed that the learning rate tends to be higher during fine-tuning when using more GPUs. For instance, a run with 4 GPUs (gradient_accumulation_steps=1) and a run with 2 GPUs (gradient_accumulation_steps=2) produce significantly different learning rate curves. Notably, the learning rate in the 4-GPU run fails to converge by the end of training, resulting in poorer performance. The learning rate curves:

[Figure: learning rate curves for the 4-GPU and 2-GPU runs]

After reviewing the code, I found what looks like a bug in /finetune/trainer.py:

        num_update_steps_per_epoch = math.ceil(len(self.data_loader) / self.args.gradient_accumulation_steps)
        if self.args.train_steps is None:
            self.args.train_steps = self.args.train_epochs * num_update_steps_per_epoch
            self.state.overwrote_max_train_steps = True

        use_deepspeed_lr_scheduler = (
            self.accelerator.state.deepspeed_plugin is not None
            and "scheduler" in self.accelerator.state.deepspeed_plugin.deepspeed_config
        )
        total_training_steps = self.args.train_steps * self.accelerator.num_processes
        num_warmup_steps = self.args.lr_warmup_steps * self.accelerator.num_processes

        if use_deepspeed_lr_scheduler:
            from accelerate.utils import DummyScheduler

            lr_scheduler = DummyScheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                total_num_steps=total_training_steps,
                num_warmup_steps=num_warmup_steps,
            )
        else:
            lr_scheduler = get_scheduler(
                name=self.args.lr_scheduler,
                optimizer=optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=total_training_steps,
                num_cycles=self.args.lr_num_cycles,
                power=self.args.lr_power,
            )

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

The line total_training_steps = self.args.train_steps * self.accelerator.num_processes seems incorrect. The GPUs are used for data parallelism, so the scheduler's total training steps should be based on self.args.train_steps directly rather than multiplied by self.accelerator.num_processes.
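For reference, here is a minimal sketch of the change I have in mind, assuming accelerate scales the schedule itself for multi-GPU training once the scheduler is prepared (applying the same reasoning to num_warmup_steps; the rest of the scheduler construction stays unchanged):

    # Sketch of the proposed fix: configure the scheduler with single-GPU
    # step counts instead of multiplying by accelerator.num_processes.
    total_training_steps = self.args.train_steps
    num_warmup_steps = self.args.lr_warmup_steps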

Additionally, in the following code from the training loop:

            for step, batch in enumerate(self.data_loader):
                logger.debug(f"Starting step {step + 1}")
                logs = {}

                with accelerator.accumulate(models_to_accumulate):
                    # These weighting schemes use a uniform timestep sampling and instead post-weight the loss
                    loss = self.compute_loss(batch)
                    if isinstance(loss, list):
                        if len(loss) == 3:
                            assert self.args.loss == 'cosine_matrix_margin' or self.args.loss == 'cosine_temporal_margin' or self.args.loss == 'cosine_margin_temporal_margin' or self.args.loss == 'temporal_margin_no_split' or self.args.loss == 'temporal_margin_no_split_cls_temporal'
                            loss_diffusion = loss[0]
                            loss_align = loss[1]
                            loss_matrix_distance = loss[2]
                            if loss_align is not None:
                                loss = loss_diffusion + self.args.proj_eff * (loss_align + loss_matrix_distance)
                            else:
                                loss = loss_diffusion + self.args.proj_eff * loss_matrix_distance
                        else:
                            # naive cosine loss
                            loss_diffusion = loss[0]
                            loss_align = loss[1]
                            loss_matrix_distance = None
                            loss = loss_diffusion + self.args.proj_eff * loss_align 
                    else:
                        loss_align = None 
                        loss_diffusion = None
                        loss_matrix_distance = None
                        
                    accelerator.backward(loss)

                    if accelerator.sync_gradients:
                        if accelerator.distributed_type == DistributedType.DEEPSPEED:
                            grad_norm = self.components.transformer.get_global_grad_norm()
                            # In some cases the grad norm may not return a float
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()
                        else:
                            grad_norm = accelerator.clip_grad_norm_(
                                self.components.transformer.parameters(), self.args.max_grad_norm
                            )
                            if torch.is_tensor(grad_norm):
                                grad_norm = grad_norm.item()

                        logs["grad_norm"] = grad_norm

                    self.optimizer.step()
                    self.lr_scheduler.step()
                    self.optimizer.zero_grad()

It appears that optimizer.step() and lr_scheduler.step() are called on every step, regardless of the value of self.args.gradient_accumulation_steps. Shouldn't they be called only when step % self.args.gradient_accumulation_steps == 0?

aHapBean · Mar 25 '25

Thank you for pointing this out. Indeed, accelerate already handles the multi-GPU parallelism on its own, so the scheduler should be configured with single-GPU settings before it is prepared.

Regarding the second point: inside the with accelerator.accumulate(models_to_accumulate) context, accelerate handles gradient accumulation on its own, as seen here.
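For example, here is a minimal toy sketch (a dummy model and dataset, not the CogVideo trainer) of the pattern accelerate expects: optimizer.step() and lr_scheduler.step() are called on every micro-step, and accelerate itself skips the real update until the accumulation boundary.

    # Toy example of accelerate's gradient accumulation: the wrapped optimizer
    # and scheduler only perform a real update when sync_gradients is True.
    import torch
    from accelerate import Accelerator
    from torch.utils.data import DataLoader, TensorDataset

    accelerator = Accelerator(gradient_accumulation_steps=2)

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
    dataset = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
    data_loader = DataLoader(dataset, batch_size=4)

    model, optimizer, data_loader, lr_scheduler = accelerator.prepare(
        model, optimizer, data_loader, lr_scheduler
    )

    for step, (x, y) in enumerate(data_loader):
        with accelerator.accumulate(model):
            loss = torch.nn.functional.mse_loss(model(x), y)
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                # True only on the micro-steps where the accumulated
                # gradients are actually applied (every 2nd batch here).
                print(f"real optimizer update at micro-step {step}")
            optimizer.step()       # skipped internally until the boundary
            lr_scheduler.step()
            optimizer.zero_grad()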

Thank you again for pointing it out, and you are welcome to submit a PR to fix this error.

OleehyO · Mar 25 '25

Thank you for your response. My questions are resolved.

I will submit a PR in the next few days.

aHapBean · Mar 25 '25