Why don't we use MSE as a reconstruction loss for a VAE?
Hi,
I am wondering whether there is a theoretical reason for using BCE as the reconstruction loss for variational auto-encoders. Couldn't we simply use MSE or a norm-based reconstruction loss instead?
Best Regards
+1
According to the original VAE paper [1], BCE is used because the decoder is implemented as MLP+sigmoid, whose output can be viewed as the parameter of a Bernoulli distribution. You can use MSE if you implement a Gaussian decoder instead. Take the following pseudocode as an example:
mu = MLP(z)                           # mean of p(x|z)
sigma = MLP(z)                        # standard deviation of p(x|z)
reconstruction = Gaussian(mu, sigma)  # Gaussian likelihood instead of Bernoulli
Reference
[1] Kingma & Welling, "Auto-Encoding Variational Bayes", https://arxiv.org/abs/1312.6114
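To make the pseudocode concrete, here is a minimal self-contained sketch of a Gaussian decoder in PyTorch. The layer sizes and the names decoder_net, mu_layer, and logvar_layer are illustrative assumptions, not taken from the paper or from this repository:

import torch
import torch.nn as nn

z_dim, hidden_dim, x_dim = 20, 400, 784           # illustrative sizes

decoder_net = nn.Sequential(nn.Linear(z_dim, hidden_dim), nn.ReLU())
mu_layer = nn.Linear(hidden_dim, x_dim)           # mean of p(x|z)
logvar_layer = nn.Linear(hidden_dim, x_dim)       # log-variance of p(x|z)

z = torch.randn(8, z_dim)                         # a batch of latent codes
h = decoder_net(z)
mu, sigma = mu_layer(h), (0.5 * logvar_layer(h)).exp()

# The reconstruction term is then the negative log-likelihood of the data
# under this Gaussian, instead of BCE under a Bernoulli.
x = torch.rand(8, x_dim)                          # placeholder data batch
p_x_given_z = torch.distributions.Normal(mu, sigma)
reconstruction_loss = -p_x_given_z.log_prob(x).sum(dim=1).mean()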
because the decoder is implemented as MLP+sigmoid, whose output can be viewed as the parameter of a Bernoulli distribution.
Does this mean that in order to model a "continuous" distribution like a Gaussian, we should not use a sigmoid as the output layer, and should replace it with tanh or even a flattened conv layer, for example?
Yes, that is the idea. However, I don't think tanh would be an appropriate choice for the output layer in this case, as we are fitting values in [0, 1] rather than [-1, 1]. In fact, an error is raised if a negative input is fed into torch.nn.functional.binary_cross_entropy().
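For reference, a tiny example of that error (the tensors are made up, and the exact message depends on the PyTorch version):

import torch
import torch.nn.functional as F

recon = torch.tensor([-0.1, 0.5, 0.9])   # contains a negative "reconstruction" value
target = torch.tensor([0.0, 1.0, 1.0])

# Raises a RuntimeError on recent PyTorch versions, because
# binary_cross_entropy expects its input to lie in [0, 1].
F.binary_cross_entropy(recon, target)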
Then would using a conv layer and flattening it be a possible way to model the output layer?
because the decoder is implemented as MLP+sigmoid, whose output can be viewed as the parameter of a Bernoulli distribution.
Does this mean that in order to model a "continuous" distribution like a Gaussian, we should not use a sigmoid as the output layer, and should replace it with tanh or even a flattened conv layer, for example?
If I change https://github.com/pytorch/examples/blob/master/vae/main.py#L60 from .sigmoid() to .tanh(), and https://github.com/pytorch/examples/blob/master/vae/main.py#L74 from BCE to MSE, will that make this VAE attempt a Gaussian reconstruction and work for [-1, 1]? I would appreciate any input.
@muammar To approximate a Gaussian output distribution, it usually works fine to use no activation function in the last layer and to interpret the output as the mean of a normal distribution. If we assume a constant variance for that distribution, we naturally end up with MSE as the loss function. An alternative option is proposed by An et al.: duplicate the output layer of the decoder to model both the mean and the variance of the normal distribution, and then optimize the negative log-likelihood.
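As a quick sanity check on the constant-variance case (the tensors below are arbitrary placeholders), the Gaussian negative log-likelihood with unit variance is just the summed squared error up to a constant factor and an additive constant:

import torch
import torch.nn.functional as F

mu = torch.randn(8, 784)                 # decoder output, no activation
x = torch.rand(8, 784)                   # data batch

# Negative log-likelihood under N(mu, 1), dropping the constant log(2*pi)/2 term:
nll_unit_var = 0.5 * ((x - mu) ** 2).sum()

# The same quantity, up to a factor of 0.5, is the summed squared error:
sse = F.mse_loss(mu, x, reduction='sum')
print(torch.allclose(nll_unit_var, 0.5 * sse))   # True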