latent-diffusion
loss_simple and loss_vlb
I don't fully understand how the final loss is computed from loss_simple and loss_vlb. Both terms appear to be the same L1/L2 error, just combined with different weighting terms. For loss_vlb, I found the formula in the original DDPM paper that seems to match the implementation, but I am not clear on the weighting term applied to loss_simple, or why it is combined with loss_vlb at all. Apologies if this question is too trivial ...
```python
# Per-example L1/L2 error between model output and target, averaged over C, H, W.
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})

# Per-timestep log-variance reweighting; logvar is all zeros unless learn_logvar=True.
logvar_t = self.logvar[t.to("cpu")].to(self.device)
loss = loss_simple / torch.exp(logvar_t) + logvar_t
if self.learn_logvar:
    loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
    loss_dict.update({"logvar": self.logvar.data.mean()})

loss = self.l_simple_weight * loss.mean()

# The same unweighted L1/L2 error, now reweighted with the per-timestep VLB coefficients.
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})

loss += self.original_elbo_weight * loss_vlb
loss_dict.update({f"{prefix}/loss": loss})
```
Just found that original_elbo_weight is set to zero by default, so out of the box the VLB term is only logged and does not contribute to the gradient; the effective training objective is just the l_simple_weight-scaled loss_simple term (with the logvar_t reweighting, which is likewise inactive unless learn_logvar is enabled).
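To convince myself, a quick standalone sketch (made-up shapes and placeholder weights, mirroring the snippet rather than calling the repo code) showing that with the default original_elbo_weight = 0 the final loss reduces to the plain loss_simple mean:

```python
import torch

# Hypothetical shapes and hyperparameters, for illustration only.
B, C, H, W = 4, 3, 8, 8
num_timesteps = 1000

model_output = torch.randn(B, C, H, W)   # stands in for eps_theta(x_t, t)
target = torch.randn(B, C, H, W)         # stands in for the true noise eps
t = torch.randint(0, num_timesteps, (B,))

logvar = torch.zeros(num_timesteps)      # logvar_init = 0, so exp(logvar) = 1
lvlb_weights = torch.rand(num_timesteps) # placeholder for the DDPM VLB coefficients

l_simple_weight = 1.0
original_elbo_weight = 0.0               # the default in question

# Per-example squared error, averaged over C, H, W (the L2 branch of get_loss).
loss_simple = ((model_output - target) ** 2).mean(dim=(1, 2, 3))

logvar_t = logvar[t]
loss = l_simple_weight * (loss_simple / torch.exp(logvar_t) + logvar_t).mean()

# Same error, reweighted per timestep; multiplied by zero in the final sum.
loss_vlb = (lvlb_weights[t] * loss_simple).mean()
loss = loss + original_elbo_weight * loss_vlb

# With the defaults, the VLB term adds exactly nothing.
assert torch.isclose(loss, loss_simple.mean())
print(f"loss={loss.item():.4f}  loss_vlb={loss_vlb.item():.4f}")
```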