
Get loss=nan when finetuning the VAE

Open eeyrw opened this issue 3 years ago • 7 comments

I found the place that causes the NaN here: ldm/modules/losses/contperceptual.py

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        # d_weight goes NaN if both gradient norms overflow to inf (e.g. under fp16): inf / inf -> nan
        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight
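
For illustration, here is a minimal, self-contained sketch (not code from the repo) of how that ratio can turn into NaN when training in fp16: if both gradient norms overflow the fp16 range, the division becomes inf / inf, and the subsequent clamp does not remove the NaN.

    import torch

    # Hypothetical fp16 gradient norms, purely to illustrate the failure mode.
    nll_grad_norm = torch.tensor(70000.0, dtype=torch.float16)  # > fp16 max (65504) -> inf
    g_grad_norm = torch.tensor(70000.0, dtype=torch.float16)    # also overflows -> inf

    d_weight = nll_grad_norm / (g_grad_norm + 1e-4)  # inf / inf -> nan
    d_weight = torch.clamp(d_weight, 0.0, 1e4)       # clamp keeps nan as nan
    print(d_weight)                                  # tensor(nan, dtype=torch.float16)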

eeyrw avatar Oct 29 '22 06:10 eeyrw

You should try an fp32 VAE and optimizer.
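
A minimal sketch of what that could look like; the module and hyperparameters below are placeholders, not the actual LDM training code:

    import torch
    from torch import nn
    from torch.optim import AdamW

    # Hypothetical stand-in for the VAE; the point is the fp32 dtype, not the architecture.
    vae = nn.Conv2d(3, 3, 3, padding=1).float()     # keep weights in fp32
    optimizer = AdamW(vae.parameters(), lr=4.5e-6)  # plain fp32 optimizer instead of adam8bit

    x = torch.randn(1, 3, 32, 32)
    # Run the loss and backward pass in fp32 (no fp16 autocast), so the gradient
    # norms used for the adaptive weight cannot overflow the fp16 range.
    loss = (vae(x) - x).pow(2).mean()
    loss.backward()
    optimizer.step()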

AlexWortega avatar Nov 14 '22 14:11 AlexWortega

It's really amazing that you know I use adam8bit and fp16.

eeyrw avatar Nov 15 '22 05:11 eeyrw

@eeyrw do you see any improvements after you finetuned the VAE?

keyu-tian avatar Sep 01 '23 21:09 keyu-tian

No. I don't have enough GPU RAM, so I couldn't try any further.

eeyrw avatar Sep 06 '23 01:09 eeyrw

@eeyrw I got NaN too, but not there. It was in https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py#L117. I fixed the NaN by replacing that line with torch.sqrt(torch.sum(x**2, dim=1, keepdim=True) + eps).
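
For reference, a sketch of the patched helper; I'm assuming the line in question sits inside the normalize_tensor function, and the argument names here may not match the repo exactly:

    import torch

    def normalize_tensor(x, eps=1e-10):
        # The original computes torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)); the
        # gradient of sqrt(u) is 1 / (2 * sqrt(u)), which blows up to inf/nan when the
        # sum is exactly zero. Moving eps inside the sqrt keeps value and gradient finite.
        norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True) + eps)
        return x / (norm_factor + eps)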

keyu-tian avatar Oct 08 '23 07:10 keyu-tian

@keyu-tian Nice, that eps trick improves numerical stability a lot 😀

eeyrw avatar Oct 08 '23 09:10 eeyrw