pytorch-lightning-vae
Quality of generated images is not as good as the one in the blog
Hi William,
I am following your script to build the CIFAR-10 model. I checked the generated images after 20 epochs and they are not as good as the ones in your post. Could you share any hint on what might be wrong in my experiment? (I did not actually change the script.) Thanks a lot!

This is the image from your post at https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed

Did you solve the problem? The same thing happens on mine.
So, I want to add some more details to this issue. I slightly edited the code, as I thought something weird must be happening in the normalization process.
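```python
import matplotlib.pyplot as plt
import numpy as np
import torch
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torchvision.utils import make_grid

# assumes `vae` is the trained VAE from the blog script (latent_dim=256)
num_preds = 16
latent_dim = 256

# sample latent vectors from the standard normal prior N(0, I)
p = torch.distributions.Normal(torch.zeros(latent_dim), torch.ones(latent_dim))
z = p.rsample((num_preds,))

# decode the samples without tracking gradients
with torch.no_grad():
    pred = vae.decoder(z.to(vae.device)).cpu()

# per-channel CIFAR-10 statistics (one value per RGB channel)
normalize = cifar10_normalization()
mean, std = np.array(normalize.mean), np.array(normalize.std)
# this one prints out (0.4466666666666667 0.4913725490196078)
print(mean.min(), mean.max())

# HWC grids; the (3,) mean/std arrays broadcast over the channel axis
img1 = make_grid(pred).permute(1, 2, 0).numpy() * std + mean
img2 = make_grid(pred).permute(1, 2, 0).numpy()
img3 = make_grid(pred, normalize=True).permute(1, 2, 0).numpy()

def to_uint8(img):
    # clip to [0, 1] first so out-of-range values don't wrap around in uint8
    return (np.clip(img, 0, 1) * 255).astype('uint8')

f, ax = plt.subplots(3, figsize=(15, 15))
ax[0].set_title("base")
ax[0].imshow(to_uint8(img1))
ax[1].set_title("no normalization")
ax[1].imshow(to_uint8(img2))
ax[2].set_title("norm via torchvision")
ax[2].imshow(to_uint8(img3))
plt.show()
```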
This is the result after training for 20 epochs:

The min and max values of the mean are very close to each other, while ideally we want them to be in the [0, 1] range. So... if we replace

```python
img1 = make_grid(pred).permute(1, 2, 0).numpy() * std + mean
```

with

```python
img1 = make_grid(pred).permute(1, 2, 0).numpy()
```

the issue is gone. But this way we are not actually denormalizing anything, and we also risk some clipping. On the other hand, if we pass normalize=True, make_grid performs its own min-max normalization and brings everything into the [0, 1] range, but it does so blindly over all the pictures at once: the smallest value in the whole grid becomes 0 and the largest becomes 1, meaning that somewhere there must be full white and full black.
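To make the two options above concrete, here is a minimal NumPy sketch. It assumes the per-channel statistics of pl_bolts' cifar10_normalization (mean 125.3/255, 123.0/255, 113.9/255, which matches the printed min/max of 0.4467 and 0.4914, and std 63.0/255, 62.1/255, 66.7/255). The helpers normalize, denormalize, minmax_scale and to_uint8 are hypothetical, written only for illustration; minmax_scale approximates what torchvision's make_grid does with normalize=True, with and without scale_each, and is not the library code itself.

```python
import numpy as np

# Assumed per-channel CIFAR-10 statistics, matching pl_bolts'
# cifar10_normalization(): one value per RGB channel, scaled to [0, 1].
MEAN = np.array([125.3, 123.0, 113.9]) / 255.0
STD = np.array([63.0, 62.1, 66.7]) / 255.0


def normalize(img):
    """Per-channel standardization of a channels-last image in [0, 1]."""
    return (img - MEAN) / STD


def denormalize(img):
    """Invert normalize(), then clip so the result stays in [0, 1]."""
    return np.clip(img * STD + MEAN, 0.0, 1.0)


def minmax_scale(batch, scale_each=False):
    """Approximate make_grid(normalize=True) on an (N, H, W, 3) batch.

    scale_each=False: one min/max over the whole batch (make_grid's default),
    so only the globally darkest value maps to 0 and the brightest to 1.
    scale_each=True: every image is stretched to [0, 1] on its own.
    """
    axes = (1, 2, 3) if scale_each else None
    lo = batch.min(axis=axes, keepdims=True)
    hi = batch.max(axis=axes, keepdims=True)
    return (batch - lo) / np.maximum(hi - lo, 1e-5)


def to_uint8(img):
    """Convert a float image to uint8, clipping first to avoid wrap-around."""
    return (np.clip(img, 0.0, 1.0) * 255).astype(np.uint8)


# Two fake images: a dim one (values in [0.4, 0.5]) and a full-range one.
dim = np.linspace(0.4, 0.5, 192).reshape(8, 8, 3)
bright = np.linspace(0.0, 1.0, 192).reshape(8, 8, 3)
batch = np.stack([dim, bright])

# normalize -> denormalize recovers the original pixels
print(np.allclose(denormalize(normalize(batch)), batch))

# global scaling leaves the dim image's max at 0.5; per-image scaling stretches it to 1.0
print(minmax_scale(batch)[0].max(), minmax_scale(batch, scale_each=True)[0].max())
```

In torchvision, passing scale_each=True alongside normalize=True to make_grid gives the per-image variant. Neither min-max option is a true inverse of the normalization, though; only the explicit x * std + mean step is, and it only gives sensible pixels if the decoder really outputs values in the normalized space.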