
GANLoss for G network: goal should be to have D output >=0.5, not 1?

Open jxtps opened this issue 2 years ago • 1 comment

In model_gan.py we have:

            if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']:
                pred_g_fake = self.netD(self.E)
                D_loss = self.D_lossfn_weight * self.D_lossfn(pred_g_fake, True)

But when D & G have converged, pred_g_fake is likely to be close to midway between the fake & real labels. This means that even for pixels/examples where G has done a great job, there will still be gradients yanking it around.

This is somewhat complicated by the fact that some GAN losses use BCEWithLogitsLoss (= they sigmoid the input) and lsgan uses MSELoss (which doesn't sigmoid the input), so the meaning of the input differs between them.
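For concreteness, here is a minimal sketch of that difference (this is not the KAIR GANLoss class itself, just the two underlying torch losses it is said to wrap here):

    import torch
    import torch.nn as nn

    # Illustration only: 'gan' is assumed to use BCEWithLogitsLoss and 'lsgan' MSELoss,
    # as described above. BCEWithLogitsLoss sigmoids the raw D output internally;
    # MSELoss compares the raw D output directly against the target label.
    pred = torch.tensor([0.0, 2.0, -2.0])        # raw discriminator outputs
    target = torch.ones_like(pred)               # "real" label = 1

    bce = nn.BCEWithLogitsLoss(reduction='none')
    mse = nn.MSELoss(reduction='none')

    print(bce(pred, target))  # computed on sigmoid(pred), so 0.0 means probability 0.5
    print(mse(pred, target))  # computed on pred as-is, so 0.5 (not 0.0) is "midway"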

Anyway, something like:

            if self.opt['train']['gan_type'] in ['gan', 'lsgan', 'wgan', 'softplusgan']:
                pred_g_fake = self.netD(self.E)
                if self.opt['train']['gan_type'] == 'gan':
                    # D outputs a logit here; once it is past 0 (= sigmoid 0.5), push it to a
                    # huge logit so the BCE loss & gradient vanish: sigmoid(1e3) == 1.0
                    pred_g_fake = torch.where(pred_g_fake > 0,
                                              torch.full_like(pred_g_fake, 1e3),
                                              pred_g_fake)
                elif self.opt['train']['gan_type'] == 'lsgan':
                    # D output is compared directly to the label; stretch [0, 0.5] onto [0, 1]
                    # and cap at the real label
                    pred_g_fake = (pred_g_fake * 2).clamp(max=1.0)
                D_loss = self.D_lossfn_weight * self.D_lossfn(pred_g_fake, True)

This is just meant as an illustration, not a pull request. It obviously does very different things for gan (it cuts off at the midpoint but doesn't stretch, because it's unclear how to stretch prior to the sigmoid) versus lsgan (where it stretches [0, 0.5] => [0, 1]; it gets more complicated if the real label != 1 or the fake label != 0).

But it should enable the gradients to "calm down" when G has done a good job and D is maximally confused.
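As a rough sanity check on the lsgan branch (numbers purely illustrative, assuming real label = 1 and fake label = 0): with the stretch, any D output at or above 0.5 maps onto the real label, so the MSE loss and its gradient drop to zero exactly where D is maximally confused or already fooled.

    import torch

    # Purely illustrative, assuming real label = 1, fake label = 0.
    pred = torch.tensor([0.25, 0.5, 0.75], requires_grad=True)

    stretched = (pred * 2).clamp(max=1.0)
    loss = torch.nn.functional.mse_loss(stretched, torch.ones_like(stretched), reduction='none')
    print(loss)        # tensor([0.2500, 0.0000, 0.0000]) vs [0.5625, 0.2500, 0.0625] unstretched

    loss.sum().backward()
    print(pred.grad)   # zero gradient wherever D already outputs >= 0.5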

I guess what I'm saying is that while the objective for D should indeed be to output 0 or 1 (ignoring the sigmoid), the goal for G is less to have D output 1 and more to have it output >=0.5.

Thoughts?

jxtps commented on Apr 28, 2022

I guess this would kind of mimic how the D optimization uses both the fake & the real data, so that the gradients cancel out where the two are the same (https://github.com/cszn/KAIR/blob/master/models/model_gan.py#L249):

            # real
            pred_d_real = self.netD(self.H)                # 1) real data
            l_d_real = self.D_lossfn(pred_d_real, True)
            l_d_real.backward()
            # fake
            pred_d_fake = self.netD(self.E.detach().clone()) # 2) fake data, detach to avoid BP to G
            l_d_fake = self.D_lossfn(pred_d_fake, False)
            l_d_fake.backward()
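A tiny sketch of that cancellation (illustrative only, assuming the 'gan' type where D_lossfn wraps BCEWithLogitsLoss): wherever a fake patch is identical to the real one, D produces the same logit for both passes, and the real & fake loss terms pull that logit in opposite directions, cancelling exactly when D outputs probability 0.5.

    import torch

    # Illustration only, assuming BCEWithLogitsLoss (the 'gan' type).
    # For the same logit p seen with opposite labels:
    #   d/dp BCE(p, 1) = sigmoid(p) - 1   (real pass)
    #   d/dp BCE(p, 0) = sigmoid(p)       (fake pass)
    # The sum is 2*sigmoid(p) - 1, which is zero when D outputs probability 0.5.
    p = torch.zeros(1, requires_grad=True)   # logit 0  <=>  sigmoid(p) = 0.5

    bce = torch.nn.BCEWithLogitsLoss()
    loss = bce(p, torch.ones_like(p)) + bce(p, torch.zeros_like(p))
    loss.backward()
    print(p.grad)   # tensor([0.]) -- real & fake gradients cancel here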

jxtps commented on Apr 28, 2022