Repeated calls to generator and discriminator's forward in GAN tutorial
📚 Documentation
In the `training_step` of the `GAN(L.LightningModule)`, the generator's and discriminator's `forward` methods are called several times on the same input. This obviously slows down training because more computation is done than necessary. I wonder if we could just reuse the results of the first call; after all, the `toggle_optimizer`/`untoggle_optimizer` calls should make it safe, right? (A sketch of the reuse I have in mind follows the two lists below.)
For the generator:
- First call, to log images: `self.generated_imgs = self(z)`
- Second call, inside the generator optimization: `self.discriminator(self(z))`
- Third call, inside the discriminator optimization: `self.discriminator(self(z).detach())`

For the discriminator:
- First call, inside the generator optimization: `self.discriminator(self(z))`
- Second call, inside the discriminator optimization: `self.discriminator(self(z).detach())`
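For concreteness, here is a minimal sketch of a `training_step` that runs the generator forward only once and reuses the result at all three sites. It assumes the tutorial's manual-optimization setup (`self.automatic_optimization = False`, `toggle_optimizer`/`untoggle_optimizer`, `manual_backward`) and BCE loss; `latent_dim` follows the tutorial's hyperparameters, and the generator/discriminator definitions are omitted:

```python
import torch
import torch.nn.functional as F
import lightning as L


class GAN(L.LightningModule):
    def __init__(self, latent_dim: int = 100):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False  # manual optimization, as in the tutorial
        # self.generator / self.discriminator defined as in the tutorial (omitted)

    def forward(self, z):
        return self.generator(z)

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        opt_g, opt_d = self.optimizers()

        z = torch.randn(imgs.size(0), self.hparams.latent_dim, device=self.device)

        # One generator forward pass; the result is reused for logging,
        # the generator step, and (detached) the discriminator step.
        self.generated_imgs = self(z)

        valid = torch.ones(imgs.size(0), 1, device=self.device)
        fake = torch.zeros(imgs.size(0), 1, device=self.device)

        # Generator step: toggle_optimizer freezes the discriminator's
        # parameters so only generator weights receive updates.
        self.toggle_optimizer(opt_g)
        g_loss = F.binary_cross_entropy(self.discriminator(self.generated_imgs), valid)
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()
        self.untoggle_optimizer(opt_g)

        # Discriminator step: .detach() cuts the graph, so no second
        # generator forward (or backward through the generator) is needed.
        self.toggle_optimizer(opt_d)
        real_loss = F.binary_cross_entropy(self.discriminator(imgs), valid)
        fake_loss = F.binary_cross_entropy(self.discriminator(self.generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
        self.untoggle_optimizer(opt_d)

        self.log_dict({"g_loss": g_loss, "d_loss": d_loss}, prog_bar=True)
```

Two caveats on this sketch: because the fakes are detached before the discriminator step, no second backward touches the generator graph, so `retain_graph=True` is only required if the discriminator's output on the fakes were also cached and reused across both backward passes. And the discriminator here sees fakes from the pre-update generator rather than recomputed ones, a common variant with slightly different semantics than calling `self(z)` again.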
cc @borda
Still on the basic GAN tutorial, I spotted a few more areas for improvement:
- `on_validation_epoch_end` is ignored because `validation_step` is not defined (see the first sketch below).
- After training, the quality of the samples is still very poor. I understand that the code is just an entry point for newcomers, but this poor performance makes the user doubt the tutorial. We could easily improve this by using a small convolutional architecture for both the generator and the discriminator without complicating the code (e.g. Generator, Discriminator; see the second sketch below).
- `add_image` called with the constant argument `global_step=0` overwrites the results of previous epochs (also fixed in the first sketch below).
- The repeated calls to the generator and discriminator could be simplified to improve training speed (thanks to `retain_graph=True`; see the `training_step` sketch above).
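For the first and third points, a minimal fix could look like the following. The `validation_z` buffer and the grid logging mirror the tutorial's `on_validation_epoch_end` (and assume a TensorBoard logger); the empty `validation_step` is only there so Lightning runs the validation loop at all:

```python
import torchvision


class GAN(L.LightningModule):
    ...

    def validation_step(self, batch, batch_idx):
        # A no-op validation_step is enough for Lightning to run the
        # validation loop, so on_validation_epoch_end is no longer skipped.
        pass

    def on_validation_epoch_end(self):
        z = self.validation_z.type_as(self.generator.model[0].weight)
        sample_imgs = self(z)
        grid = torchvision.utils.make_grid(sample_imgs)
        # Log against the current epoch instead of the constant global_step=0,
        # so each epoch's grid is kept rather than overwritten.
        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
```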
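For the sample-quality point, a small DCGAN-style pair along these lines (a hypothetical sketch sized for 1×28×28 MNIST images, not the linked code) would keep the tutorial short while producing noticeably better samples. The final `Sigmoid` is there because the tutorial uses `F.binary_cross_entropy`:

```python
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, latent_dim: int = 100):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 128, 7, 1, 0, bias=False),  # 128 x 7 x 7
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),          # 64 x 14 x 14
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),            # 1 x 28 x 28
            nn.Tanh(),
        )

    def forward(self, z):
        # Reshape the flat latent vector into a 1x1 "image" for the conv stack.
        return self.model(z.view(z.size(0), -1, 1, 1))


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1, bias=False),     # 64 x 14 x 14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),   # 128 x 7 x 7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 7, 1, 0, bias=False),    # 1 x 1 x 1
            nn.Sigmoid(),
        )

    def forward(self, img):
        return self.model(img).view(-1, 1)
```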
I am willing to implement and submit a PR if you find this helpful :smiley:
Hello @jhauret, I apologize for the late reply; I was not in good shape until a few days ago. Would you be interested in sending a PR with the fix? :rabbit:
Hi @Borda, I hope you're feeling better now. Yes, I'll do it the first week of August!