hierarchical KL loss
Hi! Very impressive work, thanks for sharing! I have a question regarding the hierarchical KL loss. In the original paper, the hierarchical KL loss is stated as a divergence between the encoder and the decoder:

$$\sum_{l} \mathrm{KL}\big(q(z_l \mid x, z_{l-1}) \,\|\, p(z_l \mid z_{l-1})\big)$$
I am wondering why you modeled the KL loss between $p(z_l \mid x, z_{l-1})$ and $p(z_l \mid z_{l-1})$, both of which appear to come from the decoder:
```python
# parameters of the prior p(z_l | z_{l-1}), computed from the decoder state
mu, log_var = self.condition_z[i](decoder_out).chunk(2, dim=1)
# residual parameters, conditioned on both the encoder features xs[i] and the decoder state
delta_mu, delta_log_var = self.condition_xz[i](torch.cat([xs[i], decoder_out], dim=1)).chunk(2, dim=1)
kl_losses.append(kl_2(delta_mu, delta_log_var, mu, log_var))
```
Please let me know if there are any misunderstandings. Thanks a lot in advance!:)
Maybe it's too late, but you can check this issue. In fact, what the author denotes as $p(z_l \mid x, z_{l-1})$ is exactly the inference model $q$, since only the inference model is conditionally dependent on the input $x$.
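To see why `kl_2` takes the deltas rather than absolute posterior parameters, here is a minimal sketch of what a residual-parameterized Gaussian KL could look like. This is my own re-derivation, not the repo's exact code: I assume the inference model is defined relative to the prior as $q = \mathcal{N}(\mu + \Delta\mu,\ e^{\log\sigma^2 + \Delta\log\sigma^2})$ with prior $p = \mathcal{N}(\mu,\ e^{\log\sigma^2})$, in which case the closed-form KL depends only on the deltas (plus the prior variance for the mean term):

```python
import torch

def kl_2(delta_mu, delta_log_var, mu, log_var):
    # Assumed residual parameterization (a sketch, not the author's exact code):
    #   p = N(mu, exp(log_var))
    #   q = N(mu + delta_mu, exp(log_var + delta_log_var))
    # Closed-form KL(q || p) for diagonal Gaussians:
    #   0.5 * sum( (mu_q - mu_p)^2 / sigma_p^2
    #              + sigma_q^2 / sigma_p^2
    #              + log(sigma_p^2 / sigma_q^2) - 1 )
    return 0.5 * (
        delta_mu.pow(2) / log_var.exp()  # (mu_q - mu_p)^2 / sigma_p^2
        + delta_log_var.exp()            # sigma_q^2 / sigma_p^2
        - delta_log_var                  # log(sigma_p^2 / sigma_q^2)
        - 1.0
    ).sum(dim=1)

# sanity check: when both deltas are zero, q coincides with p and KL is exactly 0
mu = torch.randn(2, 4)
log_var = torch.randn(2, 4)
zeros = torch.zeros(2, 4)
print(kl_2(zeros, zeros, mu, log_var))
```

Under this parameterization the network only has to predict a correction on top of the prior, which is why the snippet in the question feeds `(delta_mu, delta_log_var, mu, log_var)` into the KL term even though `mu, log_var` come from the decoder.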