Getting loss=NaN when fine-tuning the VAE
I found where the NaN comes from: `ldm/modules/losses/contperceptual.py`
```python
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
    if last_layer is not None:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    else:
        nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    d_weight = d_weight * self.discriminator_weight
    return d_weight
```
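Under fp16, the gradient norms in this function can overflow to inf, and inf / inf then propagates a NaN into `d_weight`. A minimal sketch of one possible guard (not the repo's code, just the same method with the norms taken in float32):

```python
import torch

def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
    if last_layer is not None:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    else:
        nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

    # Take the norms in float32 so an fp16 overflow/underflow in the gradients
    # does not turn the ratio into inf or NaN before the clamp.
    d_weight = torch.norm(nll_grads.float()) / (torch.norm(g_grads.float()) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    d_weight = d_weight * self.discriminator_weight
    return d_weight
```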
You should try an fp32 VAE and an fp32 optimizer.
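For what it's worth, a minimal sketch of that setup, assuming a plain PyTorch loop (the `vae`, `dataloader`, and `compute_vae_loss` names are illustrative placeholders, not the repo's training script):

```python
import torch
from torch.optim import AdamW

# Keep the VAE weights and the optimizer state in full precision,
# instead of fp16 weights + an 8-bit Adam.
vae = vae.float()                              # cast an already-built VAE back to fp32
optimizer = AdamW(vae.parameters(), lr=1e-5)   # standard fp32 AdamW, no 8-bit state

for batch in dataloader:                       # illustrative dataloader
    optimizer.zero_grad()
    loss = compute_vae_loss(vae, batch)        # illustrative: reconstruction + KL + GAN loss
    loss.backward()
    optimizer.step()
```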
It's really amazing that you know I use adam8bit and fp16.
@eeyrw did you see any improvements after you fine-tuned the VAE?
No. I don't have enough GPU RAM, so I couldn't try any further.
@eeyrw I got NaN too, but not there. Mine was in https://github.com/CompVis/taming-transformers/blob/master/taming/modules/losses/lpips.py#L117. I fixed it by replacing that line with `torch.sqrt(torch.sum(x**2, dim=1, keepdim=True) + eps)`.
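For anyone else hitting this: that line sits in the `normalize_tensor` helper, as I recall. A sketch of the patched helper, with `eps` moved inside the sqrt so that a feature vector whose squared sum underflows to zero in fp16 no longer produces an infinite gradient from `sqrt(0)`:

```python
import torch

def normalize_tensor(x, eps=1e-10):
    # Original: norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)), with eps
    # added outside the sqrt. Moving eps inside keeps the backward pass finite, since the
    # derivative of sqrt blows up at exactly zero.
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True) + eps)
    return x / norm_factor
```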
@keyu-tian Nice, that eps trick improves numerical stability a lot 😀