CogVideo

Fine-tuning loss becomes NaN after some steps

Open PR-Ryan opened this issue 1 year ago • 2 comments

Hi authors,

I am fine-tuning the CogVideoX-2B model with LoRA. I have added a new loss term, with a small weight, to the original diffusion loss. Training initially seems to work fine, but after roughly 1.6K steps (batch size 24), the diffusion loss, the new loss, and the total training loss all become NaN.
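One way to narrow down where the NaN first appears is to check each loss term on its own before the terms are summed. Below is a minimal sketch, assuming the total loss is the diffusion loss plus a weighted new loss; `combine_losses`, `new_loss_weight`, and the term labels are hypothetical names, not part of the CogVideo or SAT code, and plain floats stand in for tensors.

```python
import math
from typing import Dict, List, Tuple


def combine_losses(
    diffusion_loss: float, new_loss: float, new_loss_weight: float
) -> Tuple[float, List[str]]:
    """Return the weighted total loss and the names of any non-finite terms.

    Hypothetical helper: total = diffusion_loss + new_loss_weight * new_loss.
    """
    terms: Dict[str, float] = {
        "diffusion_loss": diffusion_loss,
        "new_loss": new_loss,
    }
    # Record which individual terms are NaN or inf before they are mixed together
    bad_terms = [name for name, value in terms.items() if not math.isfinite(value)]
    total = diffusion_loss + new_loss_weight * new_loss
    return total, bad_terms


if __name__ == "__main__":
    total, bad = combine_losses(0.12, float("nan"), 0.01)
    print(bad)  # ['new_loss']
```

Under IEEE floating point, multiplying NaN (or inf) by any weight, even 0.0, still gives NaN, so a small weight on the new term does not keep a NaN in that term out of the total.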

Any insights or suggestions on what might be causing this issue would be greatly appreciated.

PR-Ryan, Sep 21 '24 10:09

Here are a few key differences between the diffusers framework and our publicly released SAT fine-tuning code. LoRA weights have a rank parameter: the 2B transformer model defaults to a rank of 128, and the 5B transformer model defaults to a rank of 256. The LoRA scale (lora_scale) is computed as alpha / lora_r, where alpha is typically set to 1 during SAT training to ensure stability and prevent underflow. A higher rank offers more expressiveness, but it also demands more memory and lengthens training time.
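To make the arithmetic concrete, here is a minimal sketch of how a scale of alpha / lora_r enters a LoRA forward pass, using plain Python lists instead of tensors. The helper names (`lora_scale`, `lora_param_count`, `lora_forward`) are hypothetical, and the layer sizes in the checks are illustrative rather than CogVideoX's real dimensions; only the rank defaults (128 for 2B, 256 for 5B) and alpha = 1 come from the reply.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def lora_scale(alpha: float, rank: int) -> float:
    """Scale applied to the low-rank update: alpha / rank."""
    if rank <= 0:
        raise ValueError("rank must be positive")
    return alpha / rank


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Extra trainable parameters for one adapted layer: A is (rank, d_in), B is (d_out, rank)."""
    return rank * d_in + d_out * rank


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a row-major matrix by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def lora_forward(w: Matrix, a: Matrix, b: Matrix, x: Vector, alpha: float, rank: int) -> Vector:
    """y = W x + (alpha / rank) * B (A x); W stays frozen, only A and B are trained."""
    base = matvec(w, x)
    update = matvec(b, matvec(a, x))
    s = lora_scale(alpha, rank)
    return [y0 + s * u for y0, u in zip(base, update)]


if __name__ == "__main__":
    # Defaults quoted in the thread, with alpha = 1
    print(lora_scale(1.0, 128))  # 0.0078125 (2B)
    print(lora_scale(1.0, 256))  # 0.00390625 (5B)
```

With alpha fixed, doubling the rank doubles the adapter's parameter count while halving the scale on its update, which is one way to read the memory/expressiveness trade-off in the reply.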

glide-the, Sep 22 '24 04:09

Thank you for your detailed reply.

I am indeed using the SAT code, and I have left the rank parameter unchanged at 128. Since the loss becomes NaN, I am wondering whether a division by zero might be occurring somewhere.
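In PyTorch, dividing a floating-point tensor by zero does not raise an error; it silently yields inf (or NaN for 0/0), which then propagates through the loss. If the new loss term divides by a quantity that can reach zero, such as a norm or a count, one common guard is to clamp the denominator. A minimal sketch in plain Python, where `safe_ratio`, `is_finite`, and `EPS` are hypothetical names and the denominator is assumed to be non-negative:

```python
import math

EPS = 1e-8  # hypothetical floor for the denominator


def safe_ratio(numerator: float, denominator: float, eps: float = EPS) -> float:
    """Divide by max(denominator, eps) so a zero denominator cannot give inf or NaN.

    Assumes the denominator is non-negative (e.g. a norm, a sum of weights, or a count).
    """
    return numerator / max(denominator, eps)


def is_finite(value: float) -> bool:
    """True unless value is NaN, +inf, or -inf."""
    return math.isfinite(value)


if __name__ == "__main__":
    print(safe_ratio(0.0, 0.0))  # 0.0 instead of NaN
    # Once an inf appears, later arithmetic can turn it into NaN
    inf = float("inf")
    print(is_finite(inf - inf))  # False (inf - inf is NaN)
```

With tensors, the same guard is usually written as `numerator / torch.clamp(denominator, min=eps)`.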

I also noticed that in the train_step function of SwissArmyTransformer/sat/training/deepspeed_training.py, the loss and the entries of the metrics dictionary are all-reduced across processes and then checked for NaN/inf. Could this step be contributing to the NaN issue? Any further insights would be much appreciated.
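On the all_reduce question: a sum all-reduce does not create NaNs by itself, but it does spread them. If any single rank produces NaN or inf, every rank receives a non-finite sum, so the NaN/inf check that follows skips the step on all ranks together. The sketch below simulates this with plain Python lists standing in for per-rank values; `all_reduce_sum` and `should_skip_step` are hypothetical stand-ins, not the torch.distributed API.

```python
import math
from typing import List


def all_reduce_sum(per_rank_values: List[float]) -> List[float]:
    """Simulate a SUM all-reduce: every rank ends up holding the same total."""
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)


def should_skip_step(per_rank_losses: List[float]) -> bool:
    """Mirror the NaN/inf check applied after averaging the reduced loss."""
    world_size = len(per_rank_losses)
    reduced = all_reduce_sum(per_rank_losses)[0] / world_size
    return math.isnan(reduced) or math.isinf(reduced)


if __name__ == "__main__":
    print(should_skip_step([0.2, 0.3, 0.25, 0.25]))           # False
    print(should_skip_step([0.2, float("nan"), 0.25, 0.25]))  # True: one bad rank affects all
```

This suggests the all-reduce acts as a messenger rather than the cause; the more useful question is which rank's forward pass first produced the non-finite value.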

SwissArmyTransformer/sat/training/deepspeed_training.py

```python
import torch

# Excerpt from SwissArmyTransformer/sat/training/deepspeed_training.py.
# backward_step and print_all are defined elsewhere in SAT (not shown here).
def train_step(data_iterator, model, optimizer, lr_scheduler, args, timers, hooks=None, single_step=False, **kwargs):
    if hooks is None:
        hooks = {}
    lm_loss_total, metrics_total, count, metrics_count = 0.0, {}, 0, {}
    forward_step = hooks["forward_step"]

    while True:
        profiling_flag = (args.profiling != -1 and args.iteration >= args.profiling)
        # Forward model for one step.
        if profiling_flag:
            torch.cuda.nvtx.range_push("forward")
        timers('forward').start()
        forward_ret = forward_step(data_iterator, model, args, timers, **kwargs)
        if isinstance(forward_ret, tuple):
            lm_loss, metrics = forward_ret
        else:
            lm_loss, metrics = forward_ret, {}
        timers('forward').stop()
        if profiling_flag:
            torch.cuda.nvtx.range_pop()

        # Check nan or inf in forward, preventing it from interfering loss scaler,
        # and all reduce metrics by the way
        if profiling_flag:
            torch.cuda.nvtx.range_push("loss_and_metrics")
        lm_loss_reduced = lm_loss.detach().clone()
        torch.distributed.all_reduce(lm_loss_reduced.data)
        lm_loss_reduced.data = lm_loss_reduced.data / args.world_size

        loss_checker = lm_loss_reduced
        for name in metrics:
            if not 'eval' in name:
                metrics[name] = metrics[name].detach().clone()
                if metrics[name].data.item() == -100:
                    cnt = torch.zeros(1, dtype=torch.int64, device=metrics[name].data.device)
                    metrics[name].data = torch.tensor(0., device=metrics[name].data.device)
                else:
                    cnt = torch.ones(1, dtype=torch.int64, device=metrics[name].data.device)
                torch.distributed.all_reduce(metrics[name].data)
                torch.distributed.all_reduce(cnt)
                if cnt.item() == 0:
                    metrics[name].data = torch.tensor(-100, device=metrics[name].data.device)
                else:
                    metrics[name].data /= cnt.cpu().item() # args.world_size
                loss_checker = loss_checker + metrics[name]
        if loss_checker.isnan().any() or loss_checker.isinf().any():
            print_all('Skipping backward and optimizer step for nan or inf in forwarding metrics/loss!')
            return lm_loss.detach(), 1, metrics

        # Accumulate the statistics
        lm_loss_total += lm_loss_reduced
        for name in metrics:
            if name not in metrics_total:
                metrics_total[name] = torch.tensor(0.0, device=metrics[name].data.device)
            if name not in metrics_count:
                metrics_count[name] = 0
            if metrics[name].data.item() != -100:
                metrics_total[name] += metrics[name]
                metrics_count[name] += 1
        count += 1
        if profiling_flag:
            torch.cuda.nvtx.range_pop()

        if profiling_flag:
            torch.cuda.nvtx.range_push("backward")
        # Calculate gradients, reduce across processes, and clip.
        timers('backward').start()
        backward_step(optimizer, model, lm_loss, args, timers)
        timers('backward').stop()
        if profiling_flag:
            torch.cuda.nvtx.range_pop()
        # Update parameters.
        skipped_iter, complete = 0, False
        if profiling_flag:
            torch.cuda.nvtx.range_push("optimizer")
        timers('optimizer').start()
        if args.deepspeed:
            if model.is_gradient_accumulation_boundary():
                model.step()
                complete = True
                if not (args.fp16 and optimizer.overflow):
                    lr_scheduler.step()
                else:
                    skipped_iter = 1
            else:
                model.step()
        else:
            raise ValueError('Currently, we only support training with deepspeed.')
        timers('optimizer').stop()
        if profiling_flag:
            torch.cuda.nvtx.range_pop()
        if complete or single_step:
            break
    lm_loss_total /= count
    metrics_total = {key: torch.tensor(-100, device=metrics_total[key].data.device) if metrics_count[key] == 0 else value / metrics_count[key] for key, value in metrics_total.items()}
    return lm_loss_total, skipped_iter, metrics_total
```
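The per-metric reduction in the excerpt averages each metric only over the ranks that actually reported it, using -100 as a sentinel for "no value on this rank". Below is a minimal pure-Python simulation of that logic (`reduce_metric` and `SENTINEL` are hypothetical names). It also shows that a NaN is not caught by the `== -100` test, so it flows into the average and from there into `loss_checker`.

```python
import math
from typing import List

SENTINEL = -100.0  # value meaning "this rank has no value for the metric"


def reduce_metric(per_rank_values: List[float]) -> float:
    """Average a metric across ranks, ignoring ranks that report the sentinel."""
    total, count = 0.0, 0
    for value in per_rank_values:
        if value == SENTINEL:
            continue  # contributes 0 to the sum and 0 to the count
        total += value  # a NaN is not equal to SENTINEL, so it lands here
        count += 1
    # No rank reported the metric: propagate the sentinel, as the excerpt does
    return SENTINEL if count == 0 else total / count


if __name__ == "__main__":
    print(reduce_metric([0.5, SENTINEL, 1.5]))  # 1.0
```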

PR-Ryan, Sep 22 '24 09:09