lecam-gan

Low-shot training

Open xinyouduogao opened this issue 3 years ago • 6 comments

How do I train LeCam-GAN on the low-shot image generation datasets? Thanks.

xinyouduogao, Jun 13 '21

Hi,

We modified the DiffAug code to train on the low-shot image generation task.

hytseng0509, Jun 23 '21

> We modify the code of DiffAug to train on the low-shot image generation task.

Hi, would you mind sharing the source code for the low-shot generation experiments? It would help us a lot.

liang-hou, Aug 03 '21

Hello, I'm also trying to implement the LeCam loss on top of DiffAug for low-shot generation. Could you please share the training script? Thanks!

aprilycliu, Sep 13 '21

Hi,

I would like to add my vote to this discussion: it would be very helpful to see the training script or the modified `training/loss.py` file. I tried to implement the method myself on top of DiffAugment, but I have not managed to reproduce the results from Table 4 of the supplementary material. Thanks in advance!

SushkoVadim, Sep 14 '21

This experiment appears only in the supplementary material and is not part of the main paper. We'd love to release the code, but it may require additional approvals. We will do our best and see what we can do.

roadjiang, Sep 14 '21

Hi, thanks a lot for answering! I understand that the clearance process for open-sourcing can be time-consuming and burdensome. To make this simpler, could I ask you to comment on my attempt to reproduce the training? Perhaps I missed some implementation details that turn out to be important. This could also help others trying to reproduce the low-shot results.

My modification to DiffAugment was to add the LeCam regularization to the `training/loss.py` module.

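1. Specifically, I added a simple EMA tracker for both the real and the fake logits to the `StyleGAN2Loss` class:

```python
# In StyleGAN2Loss.__init__: one tracker for real logits, one for fake logits
self.val_ema_real = val_EMA()
self.val_ema_fake = val_EMA()


class val_EMA():
    """Elementwise exponential moving average of the logit tensor, initialized to 0."""

    def __init__(self, ema_decay=0.99):
        self.ema_decay = ema_decay
        self.mem_value = 0

    def add_step(self, cur_values):
        # detach() so the EMA anchor carries no gradient back into D
        self.mem_value = self.ema_decay * self.mem_value + (1 - self.ema_decay) * cur_values.detach()

    def get_cur_val(self):
        return self.mem_value
```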
2. During training, I add the new logit values to the EMA accumulators and then add the regularization to the discriminator objectives:

```python
# In StyleGAN2Loss.accumulate_gradients.
# Fakes: update the fake-logit EMA, then penalize fake logits that fall below the real-logit EMA.
loss_emaCR_fake = 0
if do_emaCR:
    self.val_ema_fake.add_step(gen_logits)
    loss_emaCR_fake = self.cr_ema_lambda * torch.mean(torch.square(torch.nn.functional.relu(self.val_ema_real.get_cur_val() - gen_logits)))

with torch.autograd.profiler.record_function('Dgen_backward'):
    (loss_Dgen + loss_emaCR_fake).mean().mul(gain).backward()
# ...
# Reals: update the real-logit EMA, then penalize real logits that rise above the fake-logit EMA.
loss_emaCR_real = 0
if do_emaCR:
    self.val_ema_real.add_step(real_logits)
    loss_emaCR_real = self.cr_ema_lambda * torch.mean(torch.square(torch.nn.functional.relu(real_logits - self.val_ema_fake.get_cur_val())))

with torch.autograd.profiler.record_function(name + '_backward'):
    (real_logits * 0 + loss_Dreal + loss_Dr1 + loss_emaCR_real).mean().mul(gain).backward()
```
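For comparison, the same hinged regularizer can be written out without PyTorch. This is a minimal sketch, not the authors' code, and it makes one assumption that differs from the snippet above: each anchor is a scalar EMA of the batch-mean logit, whereas `val_EMA` keeps an elementwise EMA of the per-sample logit tensor. `ScalarEMA` and `lecam_reg` are hypothetical names used only for illustration.

```python
# Framework-free sketch of the hinged LeCam term in the form used above.
# Assumption (hypothetical, not from the thread): each EMA anchor is a scalar
# tracking the batch-mean logit, not an elementwise EMA of the logit tensor.


class ScalarEMA:
    """Scalar exponential moving average, initialized to 0 like val_EMA."""

    def __init__(self, decay=0.99):
        self.decay = decay
        self.value = 0.0

    def update(self, logits):
        batch_mean = sum(logits) / len(logits)
        self.value = self.decay * self.value + (1 - self.decay) * batch_mean


def relu(x):
    return max(x, 0.0)


def lecam_reg(real_logits, fake_logits, ema_real, ema_fake):
    """mean(relu(D(x) - alpha_F)^2) + mean(relu(alpha_R - D(G(z)))^2),
    where alpha_R / alpha_F are the real / fake EMA anchors."""
    real_term = sum(relu(r - ema_fake.value) ** 2 for r in real_logits) / len(real_logits)
    fake_term = sum(relu(ema_real.value - f) ** 2 for f in fake_logits) / len(fake_logits)
    return real_term + fake_term


# One toy step; in training the result would be scaled by cr_ema_lambda (1e-4 in the thread).
ema_real, ema_fake = ScalarEMA(), ScalarEMA()
real_logits = [2.0, 1.0]
fake_logits = [-1.0, 0.5]
ema_real.update(real_logits)  # 0.01 * 1.5 = 0.015
ema_fake.update(fake_logits)  # 0.01 * -0.25 = -0.0025
reg = lecam_reg(real_logits, fake_logits, ema_real, ema_fake)
print(round(reg, 6))
```

Both trackers start from 0, as `val_EMA` does, so on a constant input c the anchor equals c * (1 - 0.99^n) after n updates: only about 63% of c after 100 updates. The anchors therefore lag the true logit scale early in training.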

I ran the training for the same 300 kimg, with `self.cr_ema_lambda = 0.0001` and `self.ema_decay = 0.99`, which correspond to the description in the supplementary material. After training finished, I measured the following best FIDs across epochs (lower is better):

| LeCam CR | Result | Animal Face (Cat) | Animal Face (Dog) | Obama | Panda | Grumpy Cat |
|---|---|---|---|---|---|---|
| No | reported | 42.10 | 58.47 | 47.09 | 12.10 | 27.21 |
| Yes | reported | 33.16 | 54.88 | 33.16 | 10.16 | 24.93 |
| No | reproduced | 40.20 | 67.12 | 48.31 | 14.44 | 27.09 |
| Yes | reproduced | 39.55 | 64.84 | 50.80 | 14.82 | 29.66 |

Thus, I am able to reproduce the original numbers from the DiffAugment repository. However, the results after adding LeCam CR do not match Table 4; adding it even worsens FID on 3 of the 5 datasets.
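The "3 of 5" count follows directly from the two reproduced rows. A quick check, with values copied from the table (a positive difference means adding LeCam CR made FID worse):

```python
# Best FIDs from the two "reproduced" rows of the table above (lower is better).
datasets = ["Animal Face (Cat)", "Animal Face (Dog)", "Obama", "Panda", "Grumpy Cat"]
fid_baseline = [40.20, 67.12, 48.31, 14.44, 27.09]  # DiffAugment without LeCam CR
fid_lecam = [39.55, 64.84, 50.80, 14.82, 29.66]  # DiffAugment with LeCam CR

# Positive difference means adding LeCam CR increased (worsened) FID.
diffs = {name: round(lec - base, 2) for name, base, lec in zip(datasets, fid_baseline, fid_lecam)}
worse = [name for name, d in diffs.items() if d > 0]

print(diffs)
print(f"LeCam CR worsens FID on {len(worse)}/{len(datasets)} datasets: {worse}")
```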

It would indeed be very helpful to figure out where my misunderstanding lies.

Regards,
Vadim

SushkoVadim, Sep 15 '21