
Setting for gamma and max_capacity in Beta_VAE


Hi, when using beta_VAE on my own dataset, I'm not sure how to set values for gamma and max_capacity. Should I just use the defaults, or is there a rule for setting them? Does anyone have an intuition or explanation for this? Thank you!

WNZhao1988 commented on Mar 21, 2022

Hi! I've been working with β-VAE and the controlled-capacity β-VAE (the one you're referencing) quite a bit lately, and I found myself asking this exact question.

I found this repository especially helpful for figuring out the parameters. Long story short: in the paper, the reconstruction loss is summed over each image and that sum is then averaged across the batch. The implementation in this repository instead computes the pixel-wise average, which doesn't play well with the intended parameter values. Modifying the function puts the loss on a scale that aligns with the paper (Understanding disentangling in β-VAE, Burgess et al., arXiv:1804.03599, 2018), and the default values work much better out of the box.

Modified function:

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        self.num_iter += 1
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]
        kld_weight = kwargs['M_N']  # Accounts for the minibatch samples from the dataset

        # original: recons_loss = F.mse_loss(recons, input)
        # modified: sum over each image, then average over the batch
        recons_loss = F.mse_loss(recons, input, reduction="sum").div(input.shape[0])

        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        if self.loss_type == 'H':  # https://openreview.net/forum?id=Sy2fzU9gl
            loss = recons_loss + self.beta * kld_weight * kld_loss
        elif self.loss_type == 'B':  # https://arxiv.org/pdf/1804.03599.pdf
            self.C_max = self.C_max.to(input.device)
            # Anneal the capacity C linearly from 0 to C_max over C_stop_iter iterations
            C = torch.clamp(self.C_max / self.C_stop_iter * self.num_iter, 0, self.C_max.data[0])
            loss = recons_loss + self.gamma * kld_weight * (kld_loss - C).abs()
        else:
            raise ValueError('Undefined loss type.')

        return {'loss': loss, 'Reconstruction_Loss': recons_loss, 'KLD': kld_loss}
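
To make the scale difference concrete, here's a quick standalone check (a minimal sketch with made-up tensor shapes, not part of the repository):

    import torch
    import torch.nn.functional as F

    recons = torch.rand(16, 3, 64, 64)   # hypothetical batch of reconstructions
    target = torch.rand(16, 3, 64, 64)   # hypothetical batch of inputs

    mean_loss = F.mse_loss(recons, target)                       # pixel-wise mean (original)
    sum_loss = F.mse_loss(recons, target, reduction="sum") / 16  # per-image sum, batch mean (modified)

    # The modified loss is larger by a factor of C*H*W = 3*64*64 = 12288,
    # which is the scale the paper's gamma and capacity values assume.
    print(mean_loss.item(), sum_loss.item())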

Finally, as for good gamma and max_capacity values: depending on your application, you can get better reconstructions by increasing the capacity available to the model, but you'll also end up with a less regularized latent space, which may be undesirable for what you're aiming for. The KLD is summed across the latent dimensions, so if you have a latent space of size 64 and a capacity of 32, each latent distribution can diverge by 0.5 on average, or a few dimensions may diverge greatly and take up most of the available capacity. I'd recommend starting with values from 1/8× to 2× your latent size, then visualizing how your losses differ and what effect the capacity has on your reconstructions.
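
As a rough budget calculation (hypothetical numbers, just to illustrate the arithmetic above):

    latent_dim = 64
    max_capacity = 32.0

    # Average divergence budget per latent dimension, in nats
    per_dim_budget = max_capacity / latent_dim  # 0.5

    # A sweep from 1/8x to 2x the latent size, as suggested above
    candidates = [latent_dim * f for f in (0.125, 0.25, 0.5, 1.0, 2.0)]
    print(per_dim_budget, candidates)  # 0.5 [8.0, 16.0, 32.0, 64.0, 128.0]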

Gamma, on the other hand, is dataset dependent. You can run an experiment with a value of, say, 1000, work out what percentage of your loss corresponds to the KLD capacity component, and tune accordingly. If you're working with 3-channel images that have been scaled correctly for the model input, then 1000 to 10000 may be a good starting range.
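
One way to run that check without modifying the loss function (a hedged sketch; `model`, `recons`, `input`, `mu`, `log_var`, and `kld_weight` stand in for whatever your training loop already has):

    # Inside the training loop, after computing the loss dict:
    losses = model.loss_function(recons, input, mu, log_var, M_N=kld_weight)

    # For loss_type 'B', the total is recons_loss + gamma * kld_weight * |KLD - C|,
    # so the capacity component is just the difference:
    capacity_term = losses['loss'] - losses['Reconstruction_Loss']
    frac = (capacity_term / losses['loss']).item()
    print(f"KLD capacity component: {100 * frac:.1f}% of total loss")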

A very verbose reply, but I hope that helps.

mjdall commented on Mar 25, 2022

@mjdall Could you explain the self.loss_type == 'B' loss, and what the clamp, C_max, and C_stop_iter parameters are about in the paper? On another note, taking a random input and recon, the modified recons_loss you shared produces a loss of ~2200 whereas the KLD is around ~4. Even a $\beta$ of 25 is not sufficient to balance the high-valued reconstruction loss against the low KLD.
Any thoughts on that?

ranabanik commented on Mar 11, 2024