About KL Divergence loss in losses.py
```python
import torch

def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    # z_p: sample drawn from the posterior q (after the flow); logs_q: log std of q
    # m_p, logs_p: mean and log std of the prior p; z_mask: mask over valid frames
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()

    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
    kl = torch.sum(kl * z_mask)
    l = kl / torch.sum(z_mask)
    return l
```
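For context, a minimal way to call this function, assuming the `kl_loss` definition above is in scope. The shapes follow the `[batch, channels, frames]` convention used in the repository; the dummy sizes and random values below are only illustrative, not taken from the issue:

```python
import torch

# Illustrative shapes: batch of 2, 192 latent channels, 100 frames.
b, h, t = 2, 192, 100
z_p = torch.randn(b, h, t)      # sample from the posterior q (after the flow)
m_p = torch.randn(b, h, t)      # prior mean
logs_p = torch.randn(b, h, t)   # prior log std
logs_q = torch.randn(b, h, t)   # posterior log std
z_mask = torch.ones(b, 1, t)    # all frames valid in this toy example

loss = kl_loss(z_p, logs_q, m_p, logs_p, z_mask)
print(loss)  # scalar tensor
```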
KL Loss:

```python
kl = logs_p - logs_q - 0.5
kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p)
```
The formula for the KL divergence between two Gaussians $\mathcal{N}(\mu_1, \sigma_1^2)$ and $\mathcal{N}(\mu_2, \sigma_2^2)$ is:

$$D_{KL}\left(\mathcal{N}(\mu_1, \sigma_1^2)\,\|\,\mathcal{N}(\mu_2, \sigma_2^2)\right) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

But the loss in the code corresponds to

$$\log\sigma_p - \log\sigma_q - \frac{1}{2} + \frac{(z_p - m_p)^2}{2\sigma_p^2},$$

with $\sigma_q$ in the role of $\sigma_1$ and $\sigma_p$ in the role of $\sigma_2$, so the $\sigma_1^2$ term in the numerator has no counterpart.
May I ask why $σ_1^2$ is missing?
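A quick way to see how the two expressions relate numerically (toy scalar values, chosen arbitrarily just for illustration): average the code's per-element term over many samples $z_p$ drawn from $q = \mathcal{N}(m_q, \sigma_q^2)$ and put it next to the closed-form KL above.

```python
import torch

# Toy scalar example: q = N(m_q, sigma_q^2), p = N(m_p, sigma_p^2).
m_q, logs_q = torch.tensor(0.3), torch.tensor(-0.2)
m_p, logs_p = torch.tensor(-0.5), torch.tensor(0.4)

# Per-element term from the code, averaged over many samples z_p ~ q.
z_p = m_q + torch.exp(logs_q) * torch.randn(1_000_000)
stochastic = (logs_p - logs_q - 0.5
              + 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)).mean()

# Closed-form KL(q || p), with the sigma_1^2 (= sigma_q^2) term included.
closed_form = (logs_p - logs_q - 0.5
               + (torch.exp(2.0 * logs_q) + (m_q - m_p) ** 2)
               / (2.0 * torch.exp(2.0 * logs_p)))

print(stochastic.item(), closed_form.item())  # agree up to sampling noise
```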
This question is also quoted in
https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/issues/469
Same as https://github.com/jaywalnut310/vits/issues/6#issuecomment-861903556