
Problems with the IWAE ELBO Loss

Open GloryyrolG opened this issue 4 years ago • 5 comments

Hi Anand and all,

Since `weight` is the importance weighting of the samples, shouldn't it be detached from the current computational graph to obtain the expected optimization objective? See https://github.com/AntixK/PyTorch-VAE/blob/8700d245a9735640dda458db4cf40708caf2e77f/models/iwae.py#L155

GloryyrolG avatar Jun 03 '21 07:06 GloryyrolG

Actually, I see there is a `detach()` statement, but it only appears in a comment. https://github.com/AntixK/PyTorch-VAE/blob/8700d245a9735640dda458db4cf40708caf2e77f/models/iwae.py#L152
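
For reference, here is a minimal sketch of the detached version I have in mind (the function and variable names here are mine, not the repo's):

```python
import torch
import torch.nn.functional as F

def iwae_loss(log_weight: torch.Tensor) -> torch.Tensor:
    # log_weight: [B, S] per-sample log importance weights,
    # log w_i = log p(x, h_i) - log q(h_i | x)
    weight = F.softmax(log_weight, dim=-1).detach()  # normalized \tilde{w}_i, no grad
    # gradient only flows through log_weight itself, as intended
    return -(weight * log_weight).sum(dim=-1).mean()
```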

GloryyrolG avatar Jun 03 '21 07:06 GloryyrolG

Besides, as the original paper said, "Vanilla VAE separated out the KL divergence in the bound in order to achieve a simpler and lower-variance update. Unfortunately, no analogous trick applies for k > 1" (Y. Burda et al., 2016). How, then, are we still able to compute the KL divergence here? https://github.com/AntixK/PyTorch-VAE/blob/8700d245a9735640dda458db4cf40708caf2e77f/models/iwae.py#L152

GloryyrolG avatar Jun 03 '21 08:06 GloryyrolG

I also find this line very suspicious.

In the original paper Eq 14, we have:

$$\nabla_\theta \mathcal{L}_k(x) = \mathbb{E}_{\epsilon_1,\dots,\epsilon_k}\!\left[\sum_{i=1}^{k} \tilde{w}_i \,\nabla_\theta \log w\big(x, h(x, \epsilon_i, \theta), \theta\big)\right], \qquad \tilde{w}_i = \frac{w_i}{\sum_{j=1}^{k} w_j}$$

This obviously requires the weights $\tilde{w}_i$ to be detached from the graph; otherwise the gradient equals:

$$\sum_{i=1}^{k} \Big( \tilde{w}_i \,\nabla_\theta \log w_i + \log w_i \,\nabla_\theta \tilde{w}_i \Big)$$

which has an additional second term from applying the product rule to $\sum_i \tilde{w}_i \log w_i$ (here $\log w_i$ plays the role of the per-sample ELBO).
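
To make the extra term concrete, here is a toy autograd check (purely illustrative; the variable names are mine):

```python
import torch
import torch.nn.functional as F

log_w = torch.randn(4, requires_grad=True)

# Eq. 14 estimator: weights detached, grad = -sum_i w_i * grad(log w_i)
w_detached = F.softmax(log_w, dim=0).detach()
loss_a = -(w_detached * log_w).sum()
grad_a, = torch.autograd.grad(loss_a, log_w)

# Without detach: the product rule adds a log(w_i) * grad(w_i) term
log_w_b = log_w.detach().clone().requires_grad_(True)
loss_b = -(F.softmax(log_w_b, dim=0) * log_w_b).sum()
grad_b, = torch.autograd.grad(loss_b, log_w_b)

print(torch.allclose(grad_a, grad_b))  # False: the two gradients differ
```

Note also that $\nabla_\theta \log \sum_i \exp(\log w_i) = \sum_i \tilde{w}_i \nabla_\theta \log w_i$, so taking `loss = -(torch.logsumexp(log_w, dim=-1) - math.log(k))` would give the same weighted gradient without any explicit detach.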

tongdaxu avatar Mar 15 '22 08:03 tongdaxu

Besides, as the original paper said, "Vanilla VAE separated out the KL divergence in the bound in order to achieve a simpler and lower-variance update. Unfortunately, no analogous trick applies for k > 1" (Y. Burda et al., 2016). How, then, are we still able to compute the KL divergence here?

https://github.com/AntixK/PyTorch-VAE/blob/8700d245a9735640dda458db4cf40708caf2e77f/models/iwae.py#L152

I think you are also right here. The SGVB-2 estimator separates the KL divergence out of the Monte Carlo sampling over the reparameterized noise. Here, we should use the SGVB-1 estimator instead, and use Monte Carlo to estimate the whole $\log p(x, h) - \log q(h|x)$.
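
A rough sketch of that SGVB-1-style estimate (the Gaussian likelihood, shapes, and names here are my assumptions, not the repo's API):

```python
import torch
from torch.distributions import Normal

def log_importance_weight(recons, x, h, mu, log_var):
    # Assumed shapes: recons, x are [N, D_x]; h, mu, log_var are [N, D_h],
    # with N = batch_size * num_samples.
    log_px_h = Normal(recons, 1.0).log_prob(x).sum(-1)   # likelihood; unit-variance Gaussian is an assumption
    log_ph = Normal(0.0, 1.0).log_prob(h).sum(-1)        # standard normal prior on h
    log_qh_x = Normal(mu, (0.5 * log_var).exp()).log_prob(h).sum(-1)  # encoder q(h|x)
    # Monte Carlo estimate of log w = log p(x, h) - log q(h | x);
    # no analytic KL term is separated out.
    return log_px_h + log_ph - log_qh_x
```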

tongdaxu avatar Mar 15 '22 08:03 tongdaxu

Kindly refer to PR: https://github.com/AntixK/PyTorch-VAE/pull/53

tongdaxu avatar Mar 15 '22 10:03 tongdaxu