pytorch-lightning-vae

quality of image is not as good as the one in the blog

Open kelvinqin opened this issue 4 years ago • 3 comments

Hi William, I am following your script to build the CIFAR-10 model. I checked the images after 20 epochs, and they are not as good as the ones in the blog. I'm not sure if you can share any hint on what is wrong in my experiment (I did not change the script), thanks a lot!

[screenshot: 2021-02-06_02-22]

kelvinqin avatar Feb 05 '21 18:02 kelvinqin

For comparison, this is yours from the blog post at https://towardsdatascience.com/variational-autoencoder-demystified-with-pytorch-implementation-3a06bee395ed

[screenshot: 2021-02-06_02-43]

kelvinqin avatar Feb 05 '21 18:02 kelvinqin

Did you solve the problem? The same thing happens on mine.

A7F avatar Apr 26 '21 12:04 A7F

I want to add some more details to this issue. I tried to slightly edit the code, as I thought something weird must be happening with the normalization process.

import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

# sample 16 latent vectors from the standard-normal prior and decode them
# (vae is the trained model from the repo's training script)
num_preds = 16
mu = torch.zeros(256)
std = torch.ones(256)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
z = p.rsample((num_preds,))
with torch.no_grad():
    pred = vae.decoder(z.to(vae.device)).cpu()

# CIFAR-10 channel statistics used for normalization during training
normalize = cifar10_normalization()
mean, std = np.array(normalize.mean), np.array(normalize.std)

# this prints (0.4466666666666667, 0.4913725490196078)
print(mean.min(), mean.max())

# 1) de-normalize with the dataset statistics
img1 = make_grid(pred).permute(1, 2, 0).numpy() * std + mean

# 2) no de-normalization at all
img2 = make_grid(pred).permute(1, 2, 0).numpy()

# 3) let make_grid rescale to [0, 1] based on the grid's min/max
img3 = make_grid(pred, normalize=True).permute(1, 2, 0).numpy()

f, ax = plt.subplots(3, figsize=(15, 15))
ax[0].set_title("base")
ax[0].imshow((img1 * 255).astype('uint8'))
ax[1].set_title("no normalization")
ax[1].imshow((img2 * 255).astype('uint8'))
ax[2].set_title("norm via torchvision")
ax[2].imshow((img3 * 255).astype('uint8'))

This is the result after training for 20 epochs:

[screenshot: test]

The min and max of the normalization mean are very close to each other, and ideally we want the final image values to lie in the range [0, 1]. So, if we replace

old = make_grid(pred).permute(1, 2, 0).numpy() * std + mean
new = make_grid(pred).permute(1, 2, 0).numpy()

the issue is gone. But this way we are not actually de-normalizing anything, and we also risk some clipping. On the other hand, if we pass normalize=True, make_grid automatically rescales everything to the [0, 1] range based on the minimum and maximum values in the grid, but it does this blindly across all the pictures, which means that somewhere in the grid there must be a fully white and a fully black pixel.
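For what it's worth, a possible middle ground (just a minimal sketch, reusing pred, mean, std, and plt from the snippet above) would be to de-normalize with the dataset statistics and then clip to [0, 1], so the CIFAR-10 statistics are still respected but out-of-range decoder outputs do not wrap around when cast to uint8:

import numpy as np
from torchvision.utils import make_grid

# de-normalize with the CIFAR-10 statistics, then clamp to the displayable range
img = make_grid(pred).permute(1, 2, 0).numpy() * std + mean
img = np.clip(img, 0.0, 1.0)

plt.imshow((img * 255).astype('uint8'))
plt.show()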

A7F avatar May 03 '21 14:05 A7F