PyTorch-GAN

A bug in the AAE implementation?

Open Shentao-YANG opened this issue 3 years ago • 2 comments

Hi,

Thanks for this repo!

The following code from the aae.py file, lines 197-199, confuses me a bit.

real_loss = adversarial_loss(discriminator(z), valid)
fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
d_loss = 0.5 * (real_loss + fake_loss)

AFAIK, discriminator(z) should constitute the fake examples, while encoded_imgs.detach() should constitute the true ones. The goal of the discriminator is to correctly classify the noise z as a fake embedding and encoded_imgs.detach() as valid. Hence, I suggest the following modification to this block of code:

real_loss = adversarial_loss(discriminator(encoded_imgs.detach()), valid)
fake_loss = adversarial_loss(discriminator(z), fake)
d_loss = 0.5 * (real_loss + fake_loss)

Please let me know if I'm misunderstanding anything.

Shentao-YANG avatar Jun 13 '21 04:06 Shentao-YANG

Hi, this is what I understood. Based on the paper, the role of the discriminator is to predict whether a sample comes from the hidden latent code of the autoencoder (encoded_imgs.detach()) or from the sampled prior distribution (z). So the training criterion here is to match the posterior distribution of the autoencoder's latent space to the arbitrary prior distribution (here, z drawn from a normal distribution).

So by training like this:

real_loss = adversarial_loss(discriminator(z), valid) 
fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
d_loss = 0.5 * (real_loss + fake_loss)

the encoder will generate latent vectors close to our required prior (closer to the distribution of z). So after training, sampling from any part of our prior z will produce a meaningful sample.
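To make the labeling convention above concrete, here is a minimal, self-contained sketch (not the repo's exact code) of the discriminator step as written in aae.py: samples from the prior p(z) are labeled "valid" and encoder outputs are labeled "fake". The tiny discriminator network and the names latent_dim and batch_size are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
latent_dim, batch_size = 8, 4

# A toy discriminator over latent vectors (illustrative, not the repo's architecture)
discriminator = nn.Sequential(
    nn.Linear(latent_dim, 16),
    nn.LeakyReLU(0.2),
    nn.Linear(16, 1),
    nn.Sigmoid(),
)
adversarial_loss = nn.BCELoss()

valid = torch.ones(batch_size, 1)   # target label for prior samples z
fake = torch.zeros(batch_size, 1)   # target label for encoder outputs

z = torch.randn(batch_size, latent_dim)             # sample from the prior N(0, I)
encoded_imgs = torch.randn(batch_size, latent_dim)  # stand-in for the encoder's output

# Discriminator step exactly as in the repo's convention:
# prior samples -> "valid", encoder outputs (detached) -> "fake"
real_loss = adversarial_loss(discriminator(z), valid)
fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
d_loss = 0.5 * (real_loss + fake_loss)
```

Which distribution is called "valid" is only a convention: the adversarial game is symmetric, so swapping the labels (as the issue suggests) works too, as long as the encoder's generator-style loss is updated consistently to push discriminator(encoded_imgs) toward whichever label the prior samples carry.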

sanjumsanthosh avatar Jul 28 '21 17:07 sanjumsanthosh


Thank you

zhangslab avatar Nov 12 '22 09:11 zhangslab