
Multiclass training and inference

Open virilo opened this issue 4 years ago • 13 comments

I'd like to use lightweight-gan for a multiclass dataset.

The idea is to train the GAN with multiclass.

And during inference, ask the GAN for an image with multiple tags, e.g. generate an image tagged as 'boat', 'sunset' and 'people'.

Is it possible with lightweight-gan?

virilo avatar Jan 20 '21 19:01 virilo

@virilo no, not at the moment, the architecture isn't class conditional

lucidrains avatar Jan 21 '21 03:01 lucidrains

@lucidrains what do you think, would it be enough to add info about the tags into the latents before generating images, and into the discriminator's checks? Or is it a more difficult task?

Like

latents = torch.randn(batch_size, latent_dim).cuda(self.rank)
latents[:, 0] = 0.0  # tag 1 is disabled
latents[:, 1] = 1.0  # tag 2 is enabled
latents[:, 2] = 0.8  # tag 3 is enabled with 80% probability

Dok11 avatar Jan 21 '21 11:01 Dok11

It is pretty straightforward; I made a conditional version in my checkout.

You simply attach a one-hot vector the size of the number of classes to the latent code, beef up the discriminator to also output a class prediction instead of just the binary real/fake one, and finally add another CrossEntropy loss to the overall losses.
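A minimal sketch of the one-hot concatenation, using hypothetical names (num_classes, class_labels) rather than anything from the repo:

import torch
import torch.nn.functional as F

batch_size, latent_dim, num_classes = 8, 256, 10  # hypothetical sizes

# sample latents and one class label per image
latents = torch.randn(batch_size, latent_dim)
class_labels = torch.randint(0, num_classes, (batch_size,))

# attach a one-hot class vector to each latent code
one_hot = F.one_hot(class_labels, num_classes).float()
conditioned = torch.cat([latents, one_hot], dim = 1)  # (batch, latent_dim + num_classes)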

Mut1nyJD avatar Jan 21 '21 16:01 Mut1nyJD

and beef up the discriminator to also output a class prediction instead of just the binary real/fake one, and finally add another CrossEntropy loss to the overall losses

Sounds a little hard.. 🤔 Can you show code or pseudo-code for how it can be implemented? Do the losses just add up, or is there other logic?

Dok11 avatar Jan 21 '21 16:01 Dok11

yea, it's doable! but i'm focused on Alphafold2 for this week and the next

I'll circle back to this in due time!

lucidrains avatar Jan 21 '21 16:01 lucidrains

and beef up the discriminator to also output a class prediction instead of just the binary real/fake one, and finally add another CrossEntropy loss to the overall losses

Sounds a little hard.. 🤔 Can you show code or pseudo-code for how it can be implemented?

Generator:

In the generator you simply have to add a new parameter for the number of classes.

Then change self.initial_conv to:

nn.Sequential(
    nn.ConvTranspose2d(latent_dim + num_classes, latent_dim * 2, 4),
    norm_class(latent_dim * 2),
    nn.GLU(dim = 1)
)
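For context, the extra num_classes input channels come from concatenating the one-hot vector onto the latent before this layer. A rough sketch of the matching forward pass, with assumed variable names (the reshape mirrors feeding a latent into a ConvTranspose2d):

# inside a hypothetical conditional Generator.forward(self, x, class_labels)
one_hot = F.one_hot(class_labels, self.num_classes).float()
x = torch.cat([x, one_hot], dim = 1)  # (batch, latent_dim + num_classes)
x = x.view(x.size(0), -1, 1, 1)       # 1x1 spatial map for the ConvTranspose2d
x = self.initial_conv(x)              # (batch, latent_dim, 4, 4) after the GLU halves channels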

In the Discriminator you should also add the number of classes as a new parameter. Other than that, it is wise to split up to_logits, since it is best if the class prediction and the binary real/fake prediction go through the same feature branch. So separate out the last conv layer, where it reduces to 1 output channel, and add another conv layer, something like this:

self.to_logits = nn.Sequential(
    Blur(),
    nn.Conv2d(last_chan, last_chan, 3, stride = 2, padding = 1),
    nn.LeakyReLU(0.1),
)
self.realfakeConv = nn.Conv2d(last_chan, 1, 4)
self.classConv = nn.Conv2d(last_chan, self.num_classes, 4)

Obviously you then have to change the code in the forward function where to_logits is called:

logits = self.to_logits(x)
out = self.realfakeConv(logits)
out_class = self.classConv(logits)

And also return out_class from the forward
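So the tail of the forward might look roughly like this (a sketch; the actual forward in lightweight-gan also returns auxiliary reconstruction losses, omitted here):

logits = self.to_logits(x)
out = self.realfakeConv(logits)      # (batch, 1, 1, 1) real/fake logit
out_class = self.classConv(logits)   # (batch, num_classes, 1, 1) class logits

# flatten the 1x1 spatial maps and return both heads
return out.flatten(1), out_class.flatten(1)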

The rest is just a bit of plumbing.

Do the losses just add up, or is there other logic?

Yes, they just add up, though of course you could also add a weighting hyperparameter between the fake/real loss and the class loss.
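For example, the combination could look like this, with class_loss_weight as a hypothetical hyperparameter (all tensors below are stand-ins):

import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 10  # hypothetical sizes
class_loss_weight = 1.0          # hypothetical weighting hyperparameter

# stand-ins for the two discriminator outputs on a batch of real images
real_fake_logits = torch.randn(batch_size, 1)
class_logits = torch.randn(batch_size, num_classes)
class_labels = torch.randint(0, num_classes, (batch_size,))

# hinge-style real/fake term, in the spirit of what lightweight-gan already uses
real_fake_loss = F.relu(1 - real_fake_logits).mean()

# the new class term is simply added on top
class_loss = F.cross_entropy(class_logits, class_labels)
total_loss = real_fake_loss + class_loss_weight * class_loss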

Mut1nyJD avatar Jan 22 '21 09:01 Mut1nyJD

Thank you very much @Mut1nyJD, I was just about to write a class conditional version myself. I thought about implementing class-conditional batch normalization and a projection discriminator as in BigGAN etc., because I had better experiences with that technique. Your way seems easier to implement though. Did you already train a model with that class conditional extension? How are the results? Would be cool to know before I try it out myself.

taucontrib avatar Feb 18 '21 10:02 taucontrib

@xnqio In my tests it works pretty well, but I only used datasets where the number of classes is relatively small (<20). Yes, I like the simplicity of this method compared to projection and class embedding. An additional trick is to add one more class, a fake class: when you feed examples from the generator into the discriminator network, you mark them as the fake class. That gives it an even stronger learning signal.

Mut1nyJD avatar Feb 20 '21 11:02 Mut1nyJD

@Mut1nyJD thanks for the tip. Do you add the fake class in addition to the real/fake classification, or do you get rid of realfakeConv = nn.Conv2d(chan, 1, ...) and replace it with classConv = nn.Conv2d(chan, self.num_classes + 1, ...)?

taucontrib avatar Feb 22 '21 10:02 taucontrib

@xnqio Yes, I add the fake class additionally, and no, it is best to keep the binary classifier output in the discriminator as well, but let both heads go through the same feature branch beforehand, so you have one additional output compared to the unconditioned version. And for the class head you can do a standard CrossEntropyLoss with num_classes + 1.
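In other words, the class head grows by one output, and the generator's samples get that extra index as their label during the discriminator step. A sketch, with fake_class_index as an assumed convention and stand-in tensors:

import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 10  # hypothetical sizes
fake_class_index = num_classes   # the extra class reserved for generator output

# stand-ins for the (num_classes + 1)-way class logits on real and generated images
class_logits_real = torch.randn(batch_size, num_classes + 1)
class_logits_fake = torch.randn(batch_size, num_classes + 1)
real_labels = torch.randint(0, num_classes, (batch_size,))

# real images keep their true labels
loss_real = F.cross_entropy(class_logits_real, real_labels)

# generated images are all labeled as the fake class in the discriminator step
fake_labels = torch.full((batch_size,), fake_class_index, dtype = torch.long)
loss_fake = F.cross_entropy(class_logits_fake, fake_labels)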

Mut1nyJD avatar Feb 23 '21 23:02 Mut1nyJD

Alright, got it @Mut1nyJD. One last question: the standard CrossEntropyLoss doesn't work for multi-label classification. This means that the generated images only belong to the fake class and not to the one they try to imitate, right?

taucontrib avatar Feb 24 '21 10:02 taucontrib

Yes, for the output of the generator you set the labels as fake during the discriminator's training step; for the generator's step you leave the real labels intact. Oh, one thing I did notice, though I have not verified whether it is a general problem with this architecture or just with my conditional version: training with AMP seems to lead to mode collapse more often.
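To make the label handling above concrete, a sketch with stand-in tensors and assumed names:

import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 10  # hypothetical sizes

# stand-ins: class logits the discriminator produced for generated images,
# and the labels those images were conditioned on
class_logits_fake = torch.randn(batch_size, num_classes + 1)
target_labels = torch.randint(0, num_classes, (batch_size,))

# discriminator step: generated images get the reserved fake-class label
fake_labels = torch.full((batch_size,), num_classes, dtype = torch.long)
d_class_loss = F.cross_entropy(class_logits_fake, fake_labels)

# generator step: the intended labels are kept intact, so the generator is
# pushed toward samples the discriminator assigns to the target class
g_class_loss = F.cross_entropy(class_logits_fake, target_labels)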

Mut1nyJD avatar Feb 26 '21 09:02 Mut1nyJD

If we're talking about tricks now, then I should put this link here =) https://github.com/soumith/ganhacks

Dok11 avatar Feb 26 '21 10:02 Dok11