Adding VQGAN Training script

Open isamu-isozaki opened this issue 8 months ago • 35 comments

What does this PR do?

This is a VQGAN training script ported from taming-transformers, lucidrains' muse-maskgit repo, and open-muse. I plan to test it on the cifar10 dataset to confirm it works.

Steps that are still missing or need confirmation:

  • [x] Confirm einops and timm can be external dependencies. If not, convert these ops to native PyTorch
  • [x] Test on cifar10
  • [x] Add tests to test_models_vq and test_models_vae for the slight modification

Fixes #4702

Before submitting

  • [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [x] Did you read our philosophy doc (important for complex PRs)?
  • [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

isamu-isozaki avatar Oct 23 '23 01:10 isamu-isozaki

Once I've confirmed it works with cifar10, I'll remove the draft status.

isamu-isozaki avatar Oct 23 '23 01:10 isamu-isozaki

I was able to start training with this script, and I removed the einops dependency. The only additional dependency so far is timm. I plan to run this overnight on cifar10 at 128 image resolution and then remove the draft status from this PR. Also, let me know if anyone knows a good VQModel config that's easy and fast to train.

isamu-isozaki avatar Oct 30 '23 02:10 isamu-isozaki

OK! Training seems to work. Here's a wandb run on cifar10. The command to run this in 6 GB of VRAM is

accelerate launch train_vqgan.py --dataset_name=cifar10 --image_column=img --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg images/horse.jpg images/ship.jpg --resolution=128 --train_batch_size=2 --gradient_accumulation_steps=8 --report_to=wandb

For each validation image provided, the result is shown as a pair: the input image on the left and the generated image on the right (original vs. generated).

The remaining parts that I can think of are:

  • [x] Make log_validation support trackers other than wandb
  • [x] Make tqdm updates similar to other examples

I did find a bug where the global step doesn't seem to go above 3000, but once that is fixed I'll open this for review.

isamu-isozaki avatar Oct 31 '23 13:10 isamu-isozaki

The main logic is done, so I think it's ready for review. As for the 3000-step bug, I'm currently running training to see whether it happens again after the fixes.

isamu-isozaki avatar Oct 31 '23 15:10 isamu-isozaki

OK! It seems it was a hardware issue (I think). I got to step 3100. The script should be ready for review.

isamu-isozaki avatar Nov 01 '23 01:11 isamu-isozaki

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Nov 26 '23 15:11 github-actions[bot]

Hi there, what is the current status of this PR? It seems that everything works well. Will this be merged?

yqy2001 avatar Feb 24 '24 03:02 yqy2001

Ah, let me resolve the merge conflicts in the morning.

isamu-isozaki avatar Feb 27 '24 03:02 isamu-isozaki

@isamu-isozaki Please add documentation about training. I'm interested in something I can use to validate the training script.

ygean avatar Mar 05 '24 05:03 ygean

It seems the wandb page for that run is not accessible; it shows an empty page.

ygean avatar Mar 06 '24 05:03 ygean

Sorry for the late reply! And sorry, I think I accidentally deleted that run. Let me try rerunning it when I get time. And yes, my command was just

accelerate launch train_vqgan.py --dataset_name=cifar10 --image_column=img --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg images/horse.jpg images/ship.jpg --resolution=128 --train_batch_size=2 --gradient_accumulation_steps=8 --report_to=wandb

The validation images are simply taken from the cifar10 dataset. If you have more than 6 GB of VRAM, I think you can go ahead and increase the batch size for faster results.
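As a rough sketch of that trade-off (not taken from the script): if the goal is to keep the same effective batch size as the command above, 2 × 8 = 16 on a single process, a larger --train_batch_size can be paired with proportionally fewer --gradient_accumulation_steps. The helper name accumulation_for is hypothetical, and the single-process setup is an assumption.

```python
def accumulation_for(target_effective_batch, train_batch_size, num_processes=1):
    """Return the gradient accumulation steps that keep the effective batch fixed.

    Hypothetical helper for illustration; it is not part of train_vqgan.py.
    Effective batch = train_batch_size * num_processes * accumulation steps.
    """
    samples_per_micro_step = train_batch_size * num_processes
    if target_effective_batch % samples_per_micro_step != 0:
        raise ValueError("target effective batch is not divisible by the per-step batch")
    return target_effective_batch // samples_per_micro_step


# Baseline from the command above: batch size 2, 8 accumulation steps.
baseline_effective_batch = 2 * 8

for batch_size in (2, 4, 8, 16):
    steps = accumulation_for(baseline_effective_batch, batch_size)
    print(f"--train_batch_size={batch_size} --gradient_accumulation_steps={steps}")
```

Fewer accumulation steps per optimizer update means fewer forward and backward passes per update, which is where the speed-up would come from if memory allows the larger batch.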

isamu-isozaki avatar Mar 06 '24 20:03 isamu-isozaki

@sayakpaul sounds good! Let me do that by tonight

isamu-isozaki avatar Mar 07 '24 15:03 isamu-isozaki

I'll run some tests to verify the README's explanation of the hyperparameters, and I'll also explain a few things about the discriminator vs. generator loss curves.

isamu-isozaki avatar Mar 08 '24 20:03 isamu-isozaki

Sorry, I was a bit busy with finals, but I've started some training on my laptop, so I'll update the README with wandb links over the course of this week.

isamu-isozaki avatar Mar 25 '24 17:03 isamu-isozaki

No worries Isamu. Appreciate your take on this one (which still is very relevant for the community IMO).

sayakpaul avatar Mar 26 '24 02:03 sayakpaul

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 19 '24 15:04 github-actions[bot]

Sorry, I was away again for a bit because of a family emergency. I did notice that the bug where training stops after step 3100 has returned, so once that is fixed I'll add some wandb links, and then I think this PR will be done.

isamu-isozaki avatar Apr 22 '24 21:04 isamu-isozaki

After investigating a bit, I think this error is specific to my hardware. I'll test in Google Colab as a sanity check, but with every other combination of batch size and gradient accumulation steps that I tried, (1, 1), (1, 2), (2, 1), (1, 8), and (1, 4), wandb logging looks completely fine. With batch size 2 and 8 gradient accumulation steps, however, it consistently stops logging at step 3100.
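For reference, the arithmetic below shows how many optimizer steps one pass over CIFAR-10's 50,000 training images takes for each configuration listed above. It assumes the pairs are (batch size, gradient accumulation steps), a single process, and the steps-per-epoch formula commonly used in diffusers example scripts (dataloader length divided by accumulation steps, rounded up); none of this has been checked against train_vqgan.py. With batch size 2 and 8 accumulation steps, one epoch works out to 3,125 steps, which is close to the step 3100 mentioned above, though that alone does not establish a cause.

```python
import math

CIFAR10_TRAIN_IMAGES = 50_000  # size of the CIFAR-10 training split


def update_steps_per_epoch(num_images, batch_size, grad_accum_steps, num_processes=1):
    """Optimizer steps in one epoch (assumed formula, single dataloader pass)."""
    batches_per_epoch = math.ceil(num_images / (batch_size * num_processes))
    return math.ceil(batches_per_epoch / grad_accum_steps)


# (batch size, gradient accumulation steps); the last pair is the failing one.
configs = [(1, 1), (1, 2), (2, 1), (1, 8), (1, 4), (2, 8)]

for batch_size, grad_accum in configs:
    steps = update_steps_per_epoch(CIFAR10_TRAIN_IMAGES, batch_size, grad_accum)
    print(f"batch_size={batch_size} grad_accum={grad_accum}: {steps} steps per epoch")
```

In every other configuration listed, step 3100 falls well inside the first epoch under these assumptions, so comparing behavior around the first epoch boundary in each run may be a useful check.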

isamu-isozaki avatar Apr 26 '24 16:04 isamu-isozaki

Thanks for investigating. I am happy to test it on our internal cluster. Could you help me with a training command?

sayakpaul avatar Apr 26 '24 16:04 sayakpaul

@sayakpaul Thanks! That'll be very helpful. The command is

accelerate launch train_vqgan.py --dataset_name=cifar10 --image_column=img --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg --resolution=128 --train_batch_size=2 --gradient_accumulation_steps=8 --report_to=wandb --validation_steps=100 --checkpointing_steps=1000 --checkpoints_total_limit=2 --dataloader_num_workers=4

The validation images (frog, bird, car, dog) are below, in a folder called images.

Let me know if you have any issues or if you hit the same bug! (Also, feel free to increase dataloader_num_workers if data loading time is too high.)
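For readers unfamiliar with the --checkpointing_steps=1000 and --checkpoints_total_limit=2 pair in the command above, here is a minimal sketch of the rotation such flags typically imply: a checkpoint is written every 1000 steps, and older ones are pruned so that at most two remain. The checkpoint-<step> directory naming and the prune-before-save order are assumptions based on other diffusers example scripts, and checkpoints_to_delete is a hypothetical name, not a function from train_vqgan.py.

```python
import re


def checkpoints_to_delete(existing_dirs, total_limit):
    """Return the checkpoint directories to remove before saving a new one.

    Assumes directories are named "checkpoint-<step>" (an assumption, see above).
    After the new checkpoint is saved, at most `total_limit` remain.
    """
    checkpoints = [d for d in existing_dirs if re.fullmatch(r"checkpoint-\d+", d)]
    # Sort numerically by step so the oldest checkpoints come first.
    checkpoints.sort(key=lambda d: int(d.split("-")[1]))
    if len(checkpoints) < total_limit:
        return []
    num_to_remove = len(checkpoints) - total_limit + 1
    return checkpoints[:num_to_remove]


# About to save the step-3000 checkpoint with a limit of 2.
existing = ["checkpoint-2000", "checkpoint-1000", "logs"]
print(checkpoints_to_delete(existing, total_limit=2))
```

Sorting by the numeric step rather than by string matters once step counts differ in length, since "checkpoint-10000" would otherwise sort before "checkpoint-2000".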

isamu-isozaki avatar Apr 26 '24 17:04 isamu-isozaki

@isamu-isozaki started a run here: https://wandb.ai/sayakpaul/vqgan-training/runs/xxat8u5c.

Additionally, hosted the validation images here: https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images.

sayakpaul avatar Apr 27 '24 02:04 sayakpaul

@isamu-isozaki this looks really good. I just added a couple of nit comments.

One final thing before merging would be to add a test for the script. LMK if you need any help.

@sayakpaul Thanks a lot, will do! I'll try doing this tomorrow and will let you know if I run into any trouble.

isamu-isozaki avatar Apr 27 '24 04:04 isamu-isozaki

Thanks a bunch! But I think your run might have hit an OOM 😅 On my end, I got OK results even with 8 gradient accumulation steps and batch size 1.

isamu-isozaki avatar Apr 27 '24 04:04 isamu-isozaki

Sorry, I forgot to update the link; here is the other run: https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp. Running fine so far.

sayakpaul avatar Apr 27 '24 04:04 sayakpaul

@sayakpaul I think I finished the main clean up/tests. Let me know what you think!

isamu-isozaki avatar Apr 28 '24 04:04 isamu-isozaki

Additionally, would be great to have the code quality issue fixed.

sayakpaul avatar Apr 28 '24 08:04 sayakpaul

@sayakpaul I added a bit more to the README: an extra guide to training, plus wandb links. Happy to get feedback. If I get time, I think I'll add some more links for EMA and for testing models other than vgg19. But overall, I think it's done!

isamu-isozaki avatar Apr 28 '24 21:04 isamu-isozaki

Can we fix the failing tests?

sayakpaul avatar Apr 29 '24 02:04 sayakpaul

@sayakpaul sounds good! I think I figured out a fix

isamu-isozaki avatar Apr 29 '24 02:04 isamu-isozaki