lecam-gan
Low-shot training
How do I train lecam-gan on the low-shot image generation datasets? Thanks.
Hi,
We modified the DiffAug code to train on the low-shot image generation task.
Hi, would you mind sharing the source code for the low-shot generation experiments? It would help us a lot.
Hello, I'd also like to try implementing the LeCam loss on top of DiffAug for low-shot generation. Could you share the training script, please? Thanks!
Hi,
I would like to add my vote to this discussion: it would be very helpful to see the training script or the modified training.loss.py file. I tried to implement the method myself on top of DiffAugm, but I have not managed to reproduce the results from Table 4 of the supplementary material. Thanks in advance!
This experiment only appears in the supplementary material and is not part of the main paper. We'd love to release the code, but it may require additional approvals. We will do our best and see what we can do.
Hi, thanks a lot for answering! I understand that the approval process for open-sourcing can be time-consuming and burdensome. To make answering easier, could I instead ask you to comment on my attempt to reproduce the training? Perhaps I missed some implementation details that turn out to be important. This may also help others trying to reproduce the low-shot results.
My modification to DiffAugm was to add the LeCam regularization to the training.loss.py module.
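- Specifically, I added a simple EMA tracker for both the real and the fake logits to the StyleGAN2Loss class:

```python
# added to the StyleGAN2Loss class
self.val_ema_real = val_EMA()
self.val_ema_fake = val_EMA()
```

```python
class val_EMA():
    """Exponential moving average of discriminator logits."""

    def __init__(self, ema_decay=0.99):
        self.ema_decay = ema_decay
        self.mem_value = 0  # running average starts at zero

    def add_step(self, cur_values):
        # Blend the new (detached) logits into the running average.
        self.mem_value = self.ema_decay * self.mem_value + (1 - self.ema_decay) * cur_values.detach()

    def get_cur_val(self):
        return self.mem_value
```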
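With `ema_decay = 0.99` and `mem_value` starting at 0, the tracker above is pulled toward zero early in training: after n updates with a constant input c it holds (1 - 0.99^n)·c, which is only about 63% of c after 100 updates. It also keeps one running value per batch position, because `cur_values` is the full logit tensor rather than its mean. The sketch below is a hypothetical, framework-free variant (the class name `ScalarEMA` and the `bias_correction` flag are assumptions, not taken from the post or the paper) that reduces each batch to its mean and can optionally divide by (1 - decay^n), in the style of Adam's bias correction. Whether either difference affects the Table 4 numbers is not established here.

```python
class ScalarEMA:
    """Hypothetical scalar EMA of discriminator logits (illustration only).

    Unlike val_EMA, it averages each batch to one number before updating,
    and it can optionally remove the bias caused by starting at zero.
    """

    def __init__(self, decay: float = 0.99, bias_correction: bool = False):
        self.decay = decay
        self.bias_correction = bias_correction
        self.value = 0.0  # zero start, like mem_value in val_EMA
        self.steps = 0

    def add_step(self, logits) -> None:
        # Reduce the batch of logits to its mean, then blend it into the EMA.
        batch_mean = sum(logits) / len(logits)
        self.value = self.decay * self.value + (1.0 - self.decay) * batch_mean
        self.steps += 1

    def get_cur_val(self) -> float:
        if self.bias_correction and self.steps > 0:
            # Divide by (1 - decay^n) to undo the pull toward the initial 0.
            return self.value / (1.0 - self.decay ** self.steps)
        return self.value


if __name__ == "__main__":
    plain, corrected = ScalarEMA(), ScalarEMA(bias_correction=True)
    for _ in range(100):
        plain.add_step([1.0, 1.0])
        corrected.add_step([1.0, 1.0])
    # plain is about 0.634 (= 1 - 0.99**100); corrected is about 1.0
    print(round(plain.get_cur_val(), 3), round(corrected.get_cur_val(), 3))
```

The corrected estimate is unbiased from the first update, while the uncorrected one needs a few hundred updates at decay 0.99 before it approaches the true average.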
- During training, I add the new logit values to the EMA accumulators and then add the regularization to the objective functions:

```python
# for fakes
loss_emaCR_fake = 0
if do_emaCR:
    # Update the fake-logit EMA, then penalize fake logits that fall below the real-logit EMA.
    self.val_ema_fake.add_step(gen_logits)
    loss_emaCR_fake = self.cr_ema_lambda * torch.mean(torch.square(torch.nn.functional.relu(self.val_ema_real.get_cur_val() - gen_logits)))
with torch.autograd.profiler.record_function('Dgen_backward'):
    (loss_Dgen + loss_emaCR_fake).mean().mul(gain).backward()
...
# for reals
loss_emaCR_real = 0
if do_emaCR:
    # Update the real-logit EMA, then penalize real logits that rise above the fake-logit EMA.
    self.val_ema_real.add_step(real_logits)
    loss_emaCR_real = self.cr_ema_lambda * torch.mean(torch.square(torch.nn.functional.relu(real_logits - self.val_ema_fake.get_cur_val())))
with torch.autograd.profiler.record_function(name + '_backward'):
    (real_logits * 0 + loss_Dreal + loss_Dr1 + loss_emaCR_real).mean().mul(gain).backward()
```
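To check the arithmetic outside the training loop, the two penalties in the snippet above can be collapsed into one framework-free function: λ · (mean(ReLU(D(x) - α_F)²) + mean(ReLU(α_R - D(G(z)))²)), where α_R and α_F are moving averages of the real and fake logits. This is a minimal sketch, not the authors' code; the function name `lecam_reg` and the use of scalar anchors (rather than the per-sample tensors that `val_EMA` stores) are assumptions.

```python
def lecam_reg(real_logits, fake_logits, ema_real, ema_fake, lam=1e-4):
    """Hinge-style LeCam penalty with scalar EMA anchors (illustrative sketch).

    real_logits, fake_logits: sequences of discriminator outputs D(x), D(G(z)).
    ema_real, ema_fake: scalar moving averages of past real / fake logits.
    """
    # Real term: squared amount by which each real logit exceeds the fake anchor.
    real_term = sum(max(r - ema_fake, 0.0) ** 2 for r in real_logits) / len(real_logits)
    # Fake term: squared amount by which each fake logit falls below the real anchor.
    fake_term = sum(max(ema_real - f, 0.0) ** 2 for f in fake_logits) / len(fake_logits)
    return lam * (real_term + fake_term)


# Hypothetical toy values, only to exercise the arithmetic.
print(lecam_reg([1.0, -1.0], [-2.0, 0.5], ema_real=0.0, ema_fake=0.0, lam=1.0))  # prints 2.5
```

In the training code above the two terms are added in separate backward passes (the fake term in the Dgen phase, the real term in the real-image phase), so a combined value like this is only useful for inspection.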
I ran the training for the same 300 kimg with `self.cr_ema_lambda = 0.0001` and `self.ema_decay = 0.99`, which corresponds to the description in the supplementary material. After training finished, I measured the following best FID across epochs:
| Use LeCam CR? | FID (lower is better) | Animal Face - Cat | Animal Face - Dog | Obama | Panda | Grumpy Cat |
|---|---|---|---|---|---|---|
| No | reported | 42.10 | 58.47 | 47.09 | 12.10 | 27.21 |
| Yes | reported | 33.16 | 54.88 | 33.16 | 10.16 | 24.93 |
| No | reproduced | 40.20 | 67.12 | 48.31 | 14.44 | 27.09 |
| Yes | reproduced | 39.55 | 64.84 | 50.80 | 14.82 | 29.66 |
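The per-dataset effect of the regularizer in the reproduced runs can be checked with a few lines of Python. The FID values are copied from the "reproduced" rows of the table; the helper name `fid_deltas` is just for illustration.

```python
# FID values from the "reproduced" rows of the table (lower is better).
datasets = ["Animal Face - Cat", "Animal Face - Dog", "Obama", "Panda", "Grumpy Cat"]
baseline = [40.20, 67.12, 48.31, 14.44, 27.09]
with_lecam = [39.55, 64.84, 50.80, 14.82, 29.66]


def fid_deltas(base, reg):
    """Return (with LeCam) minus (baseline) FID per dataset; positive means worse."""
    return [round(r - b, 2) for b, r in zip(base, reg)]


deltas = fid_deltas(baseline, with_lecam)
worse = [name for name, d in zip(datasets, deltas) if d > 0]
print(deltas)
print(f"LeCam CR is worse on {len(worse)}/{len(datasets)} datasets: {worse}")
```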
Thus, I can mostly reproduce the baseline numbers from the DiffAugm repository (the Animal Face Dog baseline is the main exception, about 8.7 FID worse than reported). However, the results with LeCam CR do not match Table 4; adding it even makes FID worse on 3 of the 5 datasets.
It would be very helpful to figure out where my misunderstanding lies.
Regards,
Vadim