Training loss NaN

Open NSun-S opened this issue 1 year ago • 2 comments

Thanks for your awesome work. I'm trying to reproduce your results for distilling SD-XL. I ran bash examples/train/distill_xl.sh on an 8-GPU machine. It has been running normally for 25 epochs (more than 260,000 steps), but the loss has been NaN the whole time, as shown below:

step_loss: nan, step_loss_noise: nan, step_loss_kd: nan, step_loss_feat: nan

The only modification I made was changing certain lines to ensure the script runs properly. The modified code is as follows:

# Convert images to latent space
with torch.no_grad():
    # Encode with the frozen VAE and sample from the posterior latent distribution
    latents = vae.encode(
        batch["image"].to(accelerator.device, dtype=weight_dtype)
    ).latent_dist.sample()
    # Scale by the VAE scaling factor expected by the UNet
    latents = latents * vae.config.scaling_factor
    # Make sure the latents are on the right device and in the training dtype
    latents = latents.to(accelerator.device, dtype=weight_dtype)

Are there any parameters that should be adjusted? Could you provide your training loss curve or training log? Looking forward to your reply.

NSun-S avatar Nov 15 '24 03:11 NSun-S

Dear @NSun-S ,

Thanks for your interest in our work! We actually train for only 100,000 steps and have never run for that long ourselves. If a longer run causes problems, I suggest trying the bfloat16 data type, or first training without loss_kd and loss_feat for about 50,000 steps and then adding these loss terms.
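
Concretely, the staged schedule could look roughly like this inside the training loop; the names loss_noise, loss_kd, loss_feat, global_step, and warmup_steps are just for illustration, not necessarily the identifiers used in examples/train/distill_xl.sh:

# Sketch only: bring in the distillation losses after a warm-up phase.
warmup_steps = 50_000  # assumed warm-up length, per the suggestion above

# The noise-prediction loss is always used.
loss = loss_noise
if global_step >= warmup_steps:
    # Add the knowledge-distillation and feature losses only after warm-up.
    loss = loss + loss_kd + loss_feat

accelerator.backward(loss)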

Please let us know if the problem persists.

Huage001 avatar Nov 17 '24 15:11 Huage001


Thanks for your reply. I have tried bf16, and all three losses now seem to be computed normally.
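
For reference, switching to bf16 in an accelerate-based training script typically just means changing the mixed-precision setting; a minimal sketch, assuming distill_xl.sh follows the usual diffusers/accelerate pattern, is:

import torch
from accelerate import Accelerator

# Request bf16 mixed precision (often exposed as a --mixed_precision bf16 flag).
accelerator = Accelerator(mixed_precision="bf16")

# Frozen modules such as the VAE are then cast to the matching dtype.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16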

I will check the performance after training :)

NSun-S avatar Nov 19 '24 06:11 NSun-S