Multiclass training and inference
I'd like to use lightweight-gan for a multiclass dataset.
The idea is to train the GAN on multiple classes.
Then, during inference, ask the GAN for an image with multiple tags, e.g. generate an image tagged as 'boat', 'sunset' and 'people'.
Is it possible with lightweight-gan?
@virilo no, not at the moment, the architecture isn't class conditional
@lucidrains what do you think, would it be enough to add tag information to the latents before generating images, and to the discriminator's checks? Or is it a more difficult task?
Something like:
latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
latents[:, 0] = 0.0  # tag 1 is disabled for the whole batch
latents[:, 1] = 1.0  # tag 2 is enabled
latents[:, 2] = 0.8  # tag 3 is enabled with strength 0.8 (80% probability)
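For the multi-tag inference from the original question, a multi-hot version of this idea might look like the following. This is only a minimal sketch; the tag-to-dimension mapping and the generator call are assumptions, not existing lightweight-gan API:

import torch

# hypothetical mapping from tags to reserved latent dimensions
tag_to_dim = {'boat': 0, 'sunset': 1, 'people': 2}

batch_size, latent_dim = 4, 256
latents = torch.randn(batch_size, latent_dim)

# enable several tags at once by setting their reserved dimensions
for tag in ('boat', 'sunset', 'people'):
    latents[:, tag_to_dim[tag]] = 1.0

# images = generator(latents)  # assumes a generator trained with this scheme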
It is pretty straightforward; I made a conditional version in my checkout.
You simply attach a one-hot vector the size of the number of classes to the latent code, beef up the discriminator to also output a class prediction instead of just the binary one, and finally add another CrossEntropy loss to the overall losses.
> beef up the discriminator to also output a class prediction instead of just the binary one, and finally add another CrossEntropy loss to the overall losses

Sounds a little hard... 🤔 Can you show code or pseudo-code for how it can be implemented? Do the losses just add up, or is there other logic?
yea, it's doable! but i'm focused on Alphafold2 for this week and the next
I'll circle back to this in due time!
> beef up the discriminator to also output a class prediction instead of just the binary one, and finally add another CrossEntropy loss to the overall losses
>
> Sounds a little hard... 🤔 Can you show code or pseudo-code for how it can be implemented?
Generator:
In the generator you simply have to add another parameter for the number of classes, then change self.initial_conv to:
self.initial_conv = nn.Sequential(
    # input channels grow by num_classes to take the one-hot class vector
    nn.ConvTranspose2d(latent_dim + num_classes, latent_dim * 2, 4),
    norm_class(latent_dim * 2),
    nn.GLU(dim = 1)
)
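To feed that layer, the one-hot class vector gets concatenated to the latent code before the initial conv. A minimal sketch, assuming names like num_classes and class_labels that are not part of the lightweight-gan API:

import torch
import torch.nn.functional as F

batch_size, latent_dim, num_classes = 8, 256, 10

latents = torch.randn(batch_size, latent_dim)
class_labels = torch.randint(0, num_classes, (batch_size,))
one_hot = F.one_hot(class_labels, num_classes).float()

# attach the class information, then reshape to a 1x1 feature map
# so it matches the ConvTranspose2d input above
x = torch.cat([latents, one_hot], dim = 1)
x = x.view(batch_size, latent_dim + num_classes, 1, 1)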
In the discriminator you should also add the number of classes as a new parameter. Other than that, it is wise to split up to_logits: it is best if the class prediction and the binary real/fake prediction go through the same feature branch, so separate out the last conv layer (where it reduces to 1 output channel) and add another conv layer, something like this:
self.to_logits = nn.Sequential(
    Blur(),
    nn.Conv2d(last_chan, last_chan, 3, stride = 2, padding = 1),
    nn.LeakyReLU(0.1),
)

self.realfakeConv = nn.Conv2d(last_chan, 1, 4)              # binary real/fake head
self.classConv = nn.Conv2d(last_chan, self.num_classes, 4)  # class prediction head
Obviously you then have to change the code in the forward function where to_logits is called:

logits = self.to_logits(x)
out = self.realfakeConv(logits)
out_class = self.classConv(logits)

and also return out_class from the forward.
The rest is just a bit of plumbing.
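For that plumbing, the losses might be combined like the sketch below. This assumes a hinge adversarial loss similar to lightweight-gan's default; class_weight is an invented hyperparameter along the lines of the weighting discussed in the next comment:

import torch.nn.functional as F

def d_loss(real_out, real_class_logits, real_labels,
           fake_out, fake_class_logits, fake_labels,
           class_weight = 1.0):
    # hinge loss on the binary real/fake head
    adv = F.relu(1 - real_out).mean() + F.relu(1 + fake_out).mean()
    # auxiliary cross entropy on the class head, for real and fake batches;
    # flatten(1) turns the (B, C, 1, 1) conv output into (B, C)
    cls = F.cross_entropy(real_class_logits.flatten(1), real_labels) + \
          F.cross_entropy(fake_class_logits.flatten(1), fake_labels)
    return adv + class_weight * cls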
> Do the losses just add up, or is there other logic?

Yes, though of course you could also add a weighting hyperparameter between the real/fake loss and the class loss.
Thank you very much @Mut1nyJD, I was just about to write a class-conditional version myself. I thought about implementing class-conditional batch normalization and a projection discriminator as in BigGAN etc., because I have had better experiences with that technique. Your way seems easier to implement, though. Did you already train a model with this class-conditional extension? How are the results? It would be cool to know before I try it out myself.
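For reference, the projection-discriminator alternative mentioned above adds the inner product between a learned class embedding and the discriminator features to the adversarial output. A rough sketch of that idea; none of these names come from this repo:

import torch
import torch.nn as nn

class ProjectionHead(nn.Module):
    # BigGAN-style projection head for a discriminator
    def __init__(self, feat_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(feat_dim, 1)              # unconditional logit
        self.embed = nn.Embedding(num_classes, feat_dim)  # class embedding

    def forward(self, features, labels):
        # features: (B, feat_dim) pooled discriminator features
        out = self.linear(features)
        # conditional term: projection of features onto the class embedding
        return out + (self.embed(labels) * features).sum(dim = 1, keepdim = True)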
@xnqio In my tests it works pretty well, but I have only used datasets where the number of classes is relatively small (<20). Yes, I like the simplicity of this method compared to projection and class embeddings. An additional trick is to add an extra class, the fake class: when you feed examples from the generator into the discriminator network, you mark them as the fake class. That gives it an even stronger learning signal.
@Mut1nyJD thanks for the tip. Do you add the fake class on top of the real/fake classification, or do you get rid of the real/fake head realfakeConv = nn.Conv2d(chan, 1, ..) and replace it with classConv = nn.Conv2d(chan, self.num_classes + 1, ..)?
@xnqio Yes, I add the fake class additionally, and no, it is best to keep the binary real/fake output in the discriminator as well, but let both heads go through the same feature branch beforehand, so you have one additional output compared to the unconditioned version. For the class head you can then use a standard CrossEntropyLoss with num_classes + 1.
Alright, got it @Mut1nyJD. One last question: the standard CrossEntropyLoss doesn't work for multi-label classification. This means that the generated images only belong to the fake class and not to the class they try to imitate, right?
Yes, for the output of the generator you set the labels to the fake class during the discriminator's training step; during the generator's step you leave the real labels intact. Oh, one thing I did notice, though I have not verified whether it is a general problem with this architecture or just with my conditional version: using AMP training seems to lead to mode collapse more often.
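In code, that labeling scheme might look like this sketch, assuming the extended class head with num_classes + 1 outputs from above; fake_class_logits and intended_labels are illustrative names:

import torch
import torch.nn.functional as F

num_classes, batch_size = 10, 8
fake_class_index = num_classes  # the extra index reserved for the fake class

# illustrative stand-ins for the class-head output on generated images and
# the classes the generator was asked to produce
fake_class_logits = torch.randn(batch_size, num_classes + 1)
intended_labels = torch.randint(0, num_classes, (batch_size,))

# discriminator step: generated images are labeled as the fake class
d_cls = F.cross_entropy(fake_class_logits,
                        torch.full((batch_size,), fake_class_index, dtype = torch.long))

# generator step: the intended real labels are left intact
g_cls = F.cross_entropy(fake_class_logits, intended_labels)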
If we are talking about tricks now, then I should put this link here =) https://github.com/soumith/ganhacks