
Loss simple and loss vlb

TuanDTr opened this issue 1 year ago • 2 comments

I don't fully understand how the final loss is computed from loss_simple and loss_vlb. Both terms appear to be the same L1/L2 distance between model output and target, combined with different weightings. For loss_vlb, I found a formula in the original DDPM paper that seems to match the implementation, but I am not clear about the weighting term applied to loss_simple, or why it is combined with loss_vlb at all. Apologies if this question is too trivial ...

        # L_simple: element-wise L1/L2 between model output and target,
        # averaged over the C, H, W dims (one value per batch element)
        loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
        loss_dict.update({f"{prefix}/loss_simple": loss_simple.mean()})

        # Per-timestep log-variance; with the default logvar_init=0 this is
        # all zeros, so the next line reduces to loss_simple
        logvar_t = self.logvar[t.to("cpu")].to(self.device)
        loss = loss_simple / torch.exp(logvar_t) + logvar_t
        if self.learn_logvar:
            loss_dict.update({f"{prefix}/loss_gamma": loss.mean()})
            loss_dict.update({"logvar": self.logvar.data.mean()})

        loss = self.l_simple_weight * loss.mean()

        # L_vlb reuses the same distance, reweighted per timestep by
        # lvlb_weights (the true variational-bound weighting from DDPM)
        loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
        loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
        loss_dict.update({f"{prefix}/loss_vlb": loss_vlb})
        loss += self.original_elbo_weight * loss_vlb
        loss_dict.update({f"{prefix}/loss": loss})

TuanDTr commented Feb 16 '24

Just found that original_elbo_weight is set to zero by default, so the loss_vlb term contributes nothing to the final loss unless you change it in the config.
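
A minimal standalone sketch of what that implies, assuming the other defaults from ldm/models/diffusion/ddpm.py (l_simple_weight=1., logvar_init=0., learn_logvar=False); the tensors below are hypothetical stand-ins, not the real losses:

    import torch

    # Defaults as in ldm/models/diffusion/ddpm.py (assumption: unchanged config)
    l_simple_weight = 1.0
    original_elbo_weight = 0.0   # the VLB term is switched off by default
    logvar_t = torch.zeros(4)    # logvar_init=0. -> logvar[t] == 0 for every t

    loss_simple = torch.rand(4)  # hypothetical stand-in for the per-sample L1/L2 loss
    loss_vlb = torch.rand(4)     # hypothetical stand-in for the lvlb-weighted loss

    # loss_simple / exp(0) + 0 == loss_simple, so the logvar term is a no-op
    loss = l_simple_weight * (loss_simple / torch.exp(logvar_t) + logvar_t).mean()
    loss = loss + original_elbo_weight * loss_vlb.mean()  # adds exactly 0

    assert torch.isclose(loss, loss_simple.mean())  # total loss == plain L_simple

So with the default weights, training optimizes plain L_simple, exactly as in the original DDPM objective.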

kaka45inablink commented Mar 25 '24