Discriminator loss converges to zero early in training

Open • jpfeil opened this issue 1 year ago • 9 comments

I compared v0.1.26 without the GAN against v0.1.36 with the GAN on Fashion-MNIST, and got better reconstructions without the GAN: https://api.wandb.ai/links/pfeiljx/f7wdueh0

Do you have any suggestions for improving training?

I'm using a cosine scheduler for the model and discriminator. Should I use a different learning rate schedule for the discriminator?
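A separate schedule for the discriminator only requires giving it its own optimizer and its own schedule function. Below is a minimal pure-Python sketch of linear warmup followed by cosine decay; the function name, the warmup length, and the example learning rates are illustrative assumptions, not values taken from magvit2-pytorch or the paper.

```python
import math

def cosine_lr(step, total_steps, base_lr, min_lr=0.0, warmup_steps=0):
    """Linear warmup up to base_lr, then cosine decay down to min_lr at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        # ramp linearly from base_lr / warmup_steps up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # fraction of the decay phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# hypothetical setup: the discriminator gets a lower peak LR and a longer warmup
total = 5_000
gen_lrs = [cosine_lr(s, total, base_lr=2e-5) for s in range(total)]
disc_lrs = [cosine_lr(s, total, base_lr=1e-5, warmup_steps=500) for s in range(total)]
```

To drive a real optimizer, one option is to call this with `base_lr=1.0` and pass it as the multiplier to `torch.optim.lr_scheduler.LambdaLR`, which scales the optimizer's initial learning rate by the returned factor.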

I saw a similar discriminator collapse with VQ-GAN, and I read that delaying discriminator training until the generator has trained for a while may help. Maybe we could delay the discriminator until the reconstruction loss reaches a certain threshold?
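In taming-transformers the delay is a step threshold: the adversarial weight is replaced by zero until the global step reaches a start step. The first function below mirrors that idea. The class is a hypothetical version of the loss-threshold suggestion: it switches the GAN on once an exponential moving average of the reconstruction loss falls below a chosen value. The class name, the EMA decay, and the threshold are assumptions for illustration only.

```python
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    """Return `value` instead of `weight` until `global_step` reaches `threshold`."""
    if global_step < threshold:
        return value
    return weight


class ReconThresholdGate:
    """Hypothetical gate: enable the GAN once the smoothed recon loss drops below a threshold.

    Once enabled it stays enabled, so noise in the loss cannot toggle
    the discriminator on and off.
    """

    def __init__(self, threshold, decay=0.99):
        self.threshold = threshold
        self.decay = decay
        self.ema = None
        self.enabled = False

    def update(self, recon_loss):
        # exponential moving average of the reconstruction loss
        if self.ema is None:
            self.ema = recon_loss
        else:
            self.ema = self.decay * self.ema + (1.0 - self.decay) * recon_loss
        if self.ema < self.threshold:
            self.enabled = True
        return self.enabled

    def weight(self, adversarial_loss_weight):
        # scale to apply to the adversarial term of the generator loss
        return adversarial_loss_weight if self.enabled else 0.0
```

In a training loop, one would call `gate.update(recon_loss.item())` each step, multiply the adversarial term by `gate.weight(adversarial_loss_weight)`, and skip the discriminator update while `gate.enabled` is false.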

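After googling some strategies, I came across the unrolled GAN, in which the generator is trained against a copy of the discriminator unrolled a few update steps ahead, so it anticipates how the discriminator will respond. I'm not sure how difficult it would be to implement a similar strategy here.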
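In an unrolled GAN, the generator's loss is computed against a copy of the discriminator that has taken K extra gradient steps, and the generator backpropagates through those steps; the real discriminator is still updated normally. Below is a minimal PyTorch sketch, assuming PyTorch 2.x (for `torch.func.functional_call`), a discriminator that outputs one logit per sample, and plain SGD for the unrolled inner steps. The function name, the toy modules, and the hyperparameters are hypothetical, and none of this is part of magvit2-pytorch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call


def unrolled_generator_loss(generator, discriminator, real, z, unroll_steps=3, inner_lr=0.1):
    """Generator loss against a discriminator unrolled `unroll_steps` SGD steps ahead."""
    fake = generator(z)

    # differentiable copies of the discriminator weights; the real module is untouched
    params = {
        name: p.detach().clone().requires_grad_(True)
        for name, p in discriminator.named_parameters()
    }

    for _ in range(unroll_steps):
        real_logits = functional_call(discriminator, params, (real,))
        fake_logits = functional_call(discriminator, params, (fake,))
        d_loss = (
            F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
            + F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
        )
        # create_graph=True keeps each update differentiable w.r.t. the generator's output
        grads = torch.autograd.grad(d_loss, list(params.values()), create_graph=True)
        params = {name: p - inner_lr * g for (name, p), g in zip(params.items(), grads)}

    # non-saturating generator loss, evaluated with the unrolled discriminator
    fake_logits = functional_call(discriminator, params, (fake,))
    return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))


# toy usage: 2-D data, linear generator, small tanh discriminator
torch.manual_seed(0)
gen = nn.Linear(4, 2)
disc = nn.Sequential(nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 1))
g_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)

real = torch.randn(32, 2) + 3.0
z = torch.randn(32, 4)

g_opt.zero_grad()
loss = unrolled_generator_loss(gen, disc, real, z)
loss.backward()
g_opt.step()
# the ordinary one-step discriminator update on detached fakes would follow here
```

The cost grows with `unroll_steps`, since each step adds a second-order backward pass through the discriminator, which may be significant for a large discriminator like the one in a video tokenizer.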

I'm just brainstorming, so feel free to address or ignore any of these comments.

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,                # Fashion-MNIST images are grayscale
    use_gan = True,
    use_fsq = False,             # use lookup-free quantization instead of FSQ
    codebook_size = 2 ** 13,
    init_dim = 64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',     # 'videos' or 'images'; prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,        # accumulate gradients over 5 batches per optimizer step
    num_train_steps = 5_000,
    num_frames = 1,
    max_grad_norm = 1.0,
    learning_rate = 2e-5,
    accelerate_kwargs = {"split_batches": True, "mixed_precision": "fp16"},
    random_split_seed = 85,
    optimizer_kwargs = {"betas": (0.9, 0.99)},  # from the paper
    ema_kwargs = {},
    use_wandb_tracking = True,
    checkpoints_folder = f'./runs/{RUNTIME}/checkpoints',
    results_folder = f'./runs/{RUNTIME}/results',
)


with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()

jpfeil • Nov 21 '23 22:11

@jpfeil can you screenshot the paper section where they propose delaying the discriminator training? (and link the paper too)

lucidrains • Nov 21 '23 23:11

@jpfeil is your adversarial_loss_weight greater than 0? Also, try another run with perceptual_loss_weight set to 0.1.
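The first question matters because the adversarial term reaches the tokenizer only through its weight: with `adversarial_loss_weight = 0`, the adversarial term contributes nothing to the tokenizer's gradients regardless of what the discriminator does. A minimal sketch of a weighted generator objective, assuming the terms are combined as a plain weighted sum; the real tokenizer may include further terms (for example quantizer auxiliary losses or an adaptive weight on the adversarial term), and the function name is hypothetical.

```python
def generator_objective(recon_loss, perceptual_loss, adversarial_loss,
                        perceptual_loss_weight=0.1, adversarial_loss_weight=1.0):
    """Hypothetical weighted sum of the tokenizer-side loss terms."""
    return (
        recon_loss
        + perceptual_loss_weight * perceptual_loss
        + adversarial_loss_weight * adversarial_loss
    )


# with the adversarial weight at zero, only reconstruction and perceptual terms remain
no_gan = generator_objective(0.02, 0.5, 3.0, adversarial_loss_weight=0.0)
```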

lucidrains • Nov 21 '23 23:11

Thanks @lucidrains. I'll try again with those parameters. I saw it in the taming-transformers implementation here: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/vqperceptual.py#L51

jpfeil • Nov 21 '23 23:11

@jpfeil welp... whatever Robin and Patrick do goes; they are the best in the world.

let me add that

lucidrains • Nov 21 '23 23:11

@jpfeil ok, added that same functionality here. Try removing the learning rate schedule in your next run too; you shouldn't need one for something this easy.

lucidrains • Nov 21 '23 23:11

@jpfeil you don't happen to have relatives in Massachusetts, do you?

lucidrains • Nov 21 '23 23:11

@lucidrains Nice. Let me try it out again. No, I don't have any relatives in Massachusetts. Did you meet someone with the last name Pfeil?

jpfeil • Nov 21 '23 23:11

yea, I knew someone back in high school with the Pfeil family name. Tragedy struck and they moved away though. You are the second Pfeil I've met!

lucidrains • Nov 21 '23 23:11

That's amazing. It's not a common name. Sorry to hear about your friend.

jpfeil • Nov 21 '23 23:11