Adding VQGAN Training script
What does this PR do?
This is a VQGAN training script ported from taming-transformers, from lucidrains' muse-maskgit repo here, and from open-muse. I'm planning to test it on the CIFAR-10 dataset to confirm it works.
Some steps that are missing or need confirmation:
- [x] Confirm einops and timm can be external dependencies. If not, convert these ops to native PyTorch
- [x] Test on CIFAR-10
- [x] Add tests to test_models_vq and test_models_vae for the slight modification
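For context on what the port covers, the taming-transformers objective alternates between autoencoder-only training and joint adversarial training. Here is a minimal framework-free sketch of that schedule; all names (`disc_start`, the individual loss terms) are illustrative, not the script's actual API:

```python
def vqgan_losses(step, disc_start, rec_loss, codebook_loss, gen_adv_loss, disc_loss):
    """Sketch of the alternating VQGAN objective: before disc_start steps only
    the autoencoder trains (reconstruction + codebook/commitment loss);
    afterwards the discriminator joins and the generator gains an adversarial term."""
    if step < disc_start:
        # warm-up phase: train only the autoencoder, no discriminator update
        return rec_loss + codebook_loss, None
    # full objective: the generator also tries to fool the discriminator
    return rec_loss + codebook_loss + gen_adv_loss, disc_loss
```

The warm-up phase matters in practice: starting the discriminator too early tends to destabilize reconstruction quality.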
Fixes #4702
Before submitting
- [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Once I've confirmed it works with CIFAR-10, I'll remove the draft status.
I was able to start training with this script, and I removed the einops dependency. The only additional dependency so far is timm. I plan to run this overnight on CIFAR-10 at 128 image resolution and then take this PR out of draft. Also, let me know if anyone knows a good VQModel config that's easy and fast to train.
Ok! Training seems to work. Here's a wandb run on CIFAR-10; it runs in 6 GB of VRAM. The command to run it is:
```shell
accelerate launch train_vqgan.py --dataset_name=cifar10 --image_column=img --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg images/horse.jpg images/ship.jpg --resolution=128 --train_batch_size=2 --gradient_accumulation_steps=8 --report_to=wandb
```
For each validation image provided, the results are displayed like so: the input image on the left and the generated image on the right.
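One simple way to build that left/right comparison (a sketch, not the script's actual logging code) is to concatenate the input and the reconstruction horizontally before handing the result to the tracker:

```python
def side_by_side(input_img, recon_img):
    """Horizontally concatenate two images, given as H x W (x C) nested
    sequences of pixels, so the input sits on the left and the
    reconstruction on the right."""
    assert len(input_img) == len(recon_img), "images must share a height"
    # join each row of the input with the corresponding row of the reconstruction
    return [list(a) + list(b) for a, b in zip(input_img, recon_img)]
```

The same idea works with `PIL.Image.paste` or `torch.cat(..., dim=-1)` on tensors.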
The remaining parts that I can think of are
- [x] Make log_validation support trackers other than wandb
- [x] Make tqdm updates similar to other examples
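For the first item, a hypothetical dispatch over accelerate-style trackers might look like this; each tracker is assumed to expose a `.name` and a `.log(values, step=...)` method, and the payload keys are illustrative rather than the script's actual ones:

```python
def log_validation_images(trackers, images, step):
    """Route validation images to each active tracker instead of
    assuming wandb. Unknown trackers are skipped rather than crashing."""
    for tracker in trackers:
        if tracker.name == "wandb":
            # wandb accepts a dict of rich media objects
            tracker.log({"validation": images}, step=step)
        elif tracker.name == "tensorboard":
            # tensorboard-style trackers take arrays under a tagged key
            tracker.log({"validation/images": images}, step=step)
```

In the real script the wandb branch would wrap each image in `wandb.Image` before logging.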
I did find a bug where the global step doesn't seem to go above 3000, but once that is fixed I'll open this for review.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
The main logic is done, so I think it's ready for review. As for the 3000-step bug, I'm currently running training to see whether it happens again after the fixes.
Ok! It seems like it was a hardware issue (I think). Training got past step 3100. The script should be ready for review.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Hi there, what is the current status of this PR? It seems that everything works well. Will this be merged?
Ah, let me resolve the merge conflicts in the morning.
Please add documentation about training; I'm interested in being able to validate the training scripts @isamu-isozaki
It seems the wandb page doesn't permit access; it shows an empty page.
Sorry for the late reply! And sorry, I think I accidentally deleted that run. Let me try rerunning it when I get time. And yeah, my command was just:
```shell
accelerate launch train_vqgan.py --dataset_name=cifar10 --image_column=img --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg images/horse.jpg images/ship.jpg --resolution=128 --train_batch_size=2 --gradient_accumulation_steps=8 --report_to=wandb
```
The validation images are basically just taken from the CIFAR-10 dataset. If you have more than 6 GB of VRAM, you can go ahead and increase the batch size for faster results.
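The batch-size vs. gradient-accumulation trade-off can be summarized with a one-liner (a sketch; `num_processes` covers multi-GPU launches):

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_processes=1):
    # The optimizer sees this many samples per update; on a larger GPU you
    # can raise per_device_batch and lower grad_accum_steps while keeping
    # the product, and therefore the training dynamics, roughly fixed.
    return per_device_batch * grad_accum_steps * num_processes

# the command above uses --train_batch_size=2 --gradient_accumulation_steps=8,
# i.e. an effective batch of 16 on a single GPU
```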
@sayakpaul sounds good! Let me do that by tonight
I'll run some tests to verify the README explanation of the hyperparameters, and I'll also explain some things about the discriminator vs. generator loss curves.
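As background on why those two curves need balancing: taming-transformers scales the generator's adversarial term by an adaptive weight derived from the ratio of gradient norms at the decoder's last layer. A framework-free sketch of that balancing rule (constants are illustrative, and the real version computes the norms via autograd):

```python
def adaptive_disc_weight(rec_grad_norm, adv_grad_norm, eps=1e-4, max_weight=1e4):
    """Scale the generator's adversarial loss so its gradient magnitude
    roughly matches the reconstruction loss gradient, preventing either
    curve from drowning the other out."""
    weight = rec_grad_norm / (adv_grad_norm + eps)
    # clamp to a sane range so a vanishing adversarial gradient
    # cannot blow up the weight
    return max(0.0, min(weight, max_weight))
```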
Sorry, I was a bit busy with finals, but I've started some training on my laptop, so I will update the README with wandb links over the course of this week.
No worries Isamu. Appreciate your take on this one (which still is very relevant for the community IMO).
Sorry, I was away for a bit again due to a family emergency. I did notice the bug of training stopping after step 3100 returning; once that is fixed, I'll add some wandb links, and I think this PR will be done.
After investigating a bit, I think this error is specific to my hardware. I'll test in Google Colab as a sanity check, but with every other combination of batch size and gradient accumulation steps, namely (1, 1), (1, 2), (2, 1), (1, 8), and (1, 4), wandb logging looks completely fine. With batch size 2 and gradient accumulation steps 8, however, it consistently stops logging at step 3100.
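For anyone reproducing this, here is a quick standalone way to sanity-check where `global_step` should land under gradient accumulation; it mirrors accelerate's `sync_gradients` behavior but is not the script itself:

```python
def count_global_steps(num_micro_batches, grad_accum_steps):
    """global_step should advance only when gradients sync, i.e. once per
    grad_accum_steps micro-batches (ignoring a partial final window)."""
    global_step = 0
    for batch_idx in range(1, num_micro_batches + 1):
        sync_gradients = batch_idx % grad_accum_steps == 0
        if sync_gradients:
            global_step += 1  # the script logs/checkpoints only here
    return global_step

# CIFAR-10 has 50,000 train images; at batch size 2 that is 25,000
# micro-batches per epoch, i.e. 3,125 optimizer steps per epoch, which
# lands close to where logging stalled, so epoch boundaries may be
# worth checking too.
```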
Thanks for investigating. I am happy to test it on our internal cluster. Could you help me with a training command?
@sayakpaul Thanks! That'll be very helpful. The command is
```shell
accelerate launch train_vqgan.py --dataset_name=cifar10 --image_column=img --validation_images images/bird.jpg images/car.jpg images/dog.jpg images/frog.jpg --resolution=128 --train_batch_size=2 --gradient_accumulation_steps=8 --report_to=wandb --validation_steps=100 --checkpointing_steps=1000 --checkpoints_total_limit=2 --dataloader_num_workers=4
```
And the validation images are below, in a folder called images.
Let me know if you have any issues or if you get the same bug! (Also, feel free to increase dataloader_num_workers if the data loading time is too high.)
@isamu-isozaki started a run here: https://wandb.ai/sayakpaul/vqgan-training/runs/xxat8u5c.
Additionally, hosted the validation images here: https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images.
@isamu-isozaki this looks really good. I just added a couple of nit comments.
One final thing before merging would be to add a test for the script. LMK if you need any help.
@sayakpaul thanks a lot, and will do! I'll try doing this tomorrow and will let you know if I run into any trouble.
Thanks a bunch! But I think you might have hit an OOM 😅
On my end, I got ok results even with gradient accumulation steps 8 and batch size 1 here.
Sorry, I forgot to update the other link: https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp. It's running fine so far.
@sayakpaul I think I've finished the main cleanup and tests. Let me know what you think!
Additionally, would be great to have the code quality issue fixed.
@sayakpaul I added a bit more to the README as an extra guide to training, plus wandb links. Happy to get feedback. If I get time, I'll add some more links for EMA and for testing models other than VGG19. But overall, I think it's done!
Can we fix the failing tests?
@sayakpaul sounds good! I think I figured out a fix