[WIP] Add Adversarial Diffusion Distillation (ADD) Script
What does this PR do?
This PR adds an example script for adversarial diffusion distillation (ADD) (paper, code), the distillation + adversarial training method behind SD/SD-XL Turbo.
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@patrickvonplaten @sayakpaul @patil-suraj
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
There's currently an issue in the script: if student_timesteps takes the value noise_scheduler.config.num_train_timesteps - 1 (= 999), then alpha_schedule[noise_scheduler.config.num_train_timesteps - 1] == 0 because noise_scheduler has rescale_betas_zero_snr == True. When we convert the student unet's noise prediction to a predicted original sample in Denoiser, we end up dividing by alpha_schedule[999] == 0, NaNs/infs appear in student_x_0, and ultimately the gradients become NaN.
(Note that if rescale_betas_zero_snr == False, alpha_schedule[noise_scheduler.config.num_train_timesteps - 1] $\approx 0.0683$, which is far from 0, as noted by the paper that introduced rescaling to zero terminal SNR. Also note that the same thing could theoretically happen with teacher_timesteps, but it's much less likely, since the teacher timesteps are sampled over the whole range of training timesteps $\{0, 1, ..., 999\}$ rather than just 4 timesteps, one of which is $999$.)
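For concreteness, here is a minimal, self-contained sketch of the failure mode. It uses DDIMScheduler only as a convenient way to build a zero-terminal-SNR schedule, and the variable names follow the description above rather than the actual script:

```python
import torch
from diffusers import DDIMScheduler

# Any schedule rescaled to zero terminal SNR has alphas_cumprod[-1] == 0 exactly.
noise_scheduler = DDIMScheduler(num_train_timesteps=1000, rescale_betas_zero_snr=True)
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
print(alpha_schedule[999])  # tensor(0.)

# Converting an epsilon prediction to a predicted original sample divides by alpha:
#     x_0 = (x_t - sigma_t * eps_pred) / alpha_t
noisy_latents = torch.randn(1, 4, 64, 64)
eps_pred = torch.randn(1, 4, 64, 64)
t = 999
student_x_0 = (noisy_latents - sigma_schedule[t] * eps_pred) / alpha_schedule[t]
print((~torch.isfinite(student_x_0)).any())  # tensor(True): infs/NaNs propagate to the loss
```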
The ADD paper indicates that they use both $999$ as a student timestep and zero terminal SNR.
I'm not sure how they handle this divide-by-zero issue; perhaps they divide by alphas + eps for some small eps > 0 when converting a noise prediction to an original sample prediction? Thoughts @patil-suraj @sayakpaul?
[Edit: I see that EulerDiscreteScheduler handles this by setting the last alpha to a small positive value:
https://github.com/huggingface/diffusers/blob/79c380bc80051d8f82a38c3a7df6f8f4efd1633d/src/diffusers/schedulers/scheduling_euler_discrete.py#L214-L217
So that's another option.]
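For reference, the two candidate workarounds would look roughly like the sketch below in the training script (reusing the variable names from the sketch above). This is only a sketch: it assumes alpha_schedule and sigma_schedule are derived from noise_scheduler.alphas_cumprod, and eps_stab is a hypothetical constant, not something from the actual script:

```python
import torch

# Option 1: stabilize the eps -> x_0 conversion with a small constant in the denominator.
eps_stab = 1e-7  # hypothetical choice; any small positive value
student_x_0 = (noisy_latents - sigma_schedule[t] * eps_pred) / (alpha_schedule[t] + eps_stab)

# Option 2: follow EulerDiscreteScheduler and nudge the terminal alpha away from zero
# before deriving alpha_schedule / sigma_schedule.
alphas_cumprod = noise_scheduler.alphas_cumprod.clone()
alphas_cumprod[-1] = 2**-24  # smallest positive fp16 subnormal, as in the linked lines
alpha_schedule = torch.sqrt(alphas_cumprod)
sigma_schedule = torch.sqrt(1 - alphas_cumprod)
```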
If I set alpha_schedule[-1] = 2**-24, the student's predicted original sample student_x_0 no longer contains NaNs, but whenever student_timesteps is noise_scheduler.config.num_train_timesteps - 1 the predicted x_0 has very large positive and negative values (as expected). When we decode with the VAE, the decoded output student_gen_image then has NaNs at those locations (at least for the small random test VAE at dg845/tiny-random-stable-diffusion; I'm not sure what happens if a full pretrained SD/SD-XL checkpoint is used). This is true even with a small positive value several orders of magnitude larger, e.g. alpha_schedule[-1] = 2**-12.
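To give a rough sense of scale (assuming alpha_schedule stores $\sqrt{\bar{\alpha}_t}$, so that the conversion is $\hat{x}_0 = (x_t - \sqrt{1 - \bar{\alpha}_t}\,\hat{\epsilon}) / \sqrt{\bar{\alpha}_t}$): at $t = 999$ the latent is essentially pure noise and $\sqrt{1 - \bar{\alpha}_t} \approx 1$, so

$$\hat{x}_0 \approx \frac{x_t - \hat{\epsilon}}{2^{-24}} = 2^{24}\,(x_t - \hat{\epsilon}),$$

i.e. values on the order of $10^7$ (and still $\sim 4 \times 10^3$ with $2^{-12}$). Either way, the latents handed to the VAE decoder are many orders of magnitude outside the range it normally sees, so NaNs appearing downstream is not surprising.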
Thanks for all the details. Probably a very dumb ask, but have you seen any widely different training dynamics without enabling zero terminal SNR? If not, we could definitely try that out.
But if I were to try things out in order of priority:
- I would try out the Euler approach
- Replicate those changes with the current sampler, but it looks like you have already tried that
- Follow up with the ADD authors about this phenomenon
> Probably a very dumb ask, but have you seen any widely different training dynamics without enabling zero terminal SNR?
I've only tested on very small toy checkpoints like dg845/tiny-random-stable-diffusion due to a lack of computational resources, so I'm not sure if the training dynamics are different. My guess is that without enforcing zero terminal SNR the distilled checkpoint might have the medium brightness sample issue described in https://arxiv.org/pdf/2305.08891.pdf.
(BTW, would it be possible to get some GPU resources to help test the script?😅)
> (BTW, would it be possible to get some GPU resources to help test the script?😅)
Checking that internally. Will keep you posted.
I will start testing this and will keep you posted.
Excuse my ignorance, but how do I run it? I get this error trying to run either the LoRA or the SD 1.5 version:
accelerate launch train_add_distill_lora_sd_wds.py --pretrained_teacher_model C:\Users\Pablo\Documents\mobians_api\sonicDiffusionV4 --train_shards_path_or_url laion/conceptual-captions-12m-webdataset
Traceback (most recent call last):
File "C:\Users\Pablo\Downloads\diffusers\examples\add\train_add_distill_lora_sd_wds.py", line 2114, in
I think there are a few additional arguments that need to be explicitly supplied for the scripts not to raise an error. Something close to the minimal set of arguments needed is:
accelerate launch examples/add/train_add_distill_lora_sd_wds.py \
--pretrained_teacher_model="<teacher_model>" \
--train_shards_path_or_url="<dataset>" \
--output_dir="<output_dir>" \
--max_train_steps=1 \
--max_train_samples=20 \
--dataloader_num_workers=8 \
assuming the other default values work (for example, --train_batch_size is not so big that it leads to an OOM error).
Note that the scripts are a work in progress and there's no guarantee that they work currently.
Got it running, but ran into a bug when saving. Validation images also looked like random noise.
Steps: 40%|▍| 400/1000 [35:56<53:16, 5.33s/it, d_total_loss=1.46, g_adv_loss=-.086, g_distill_loss=0.123, g_total_loss01/10/2024 16:24:08 - INFO - __main__ - Running validation...
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of ../sonicDiffusionV4. | 0/5 [00:00<?, ?it/s]
Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of ../sonicDiffusionV4.
Loaded scheduler as PNDMScheduler from `scheduler` subfolder of ../sonicDiffusionV4. | 3/5 [00:00<00:00, 5.95it/s]
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 5/5 [00:00<00:00, 9.88it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Steps: 50%|▌| 500/1000 [44:53<43:40, 5.24s/it, d_total_loss=1.51, g_adv_loss=-.227, g_distill_loss=0.00964, g_total_lo01/10/2024 16:33:06 - INFO - accelerate.accelerator - Saving current state to add_model/checkpoint-500
Configuration saved in add_model/checkpoint-500/unet/config.json
Model weights saved in add_model/checkpoint-500/unet/diffusion_pytorch_model.safetensors
Traceback (most recent call last):
File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 2010, in <module>
main(args)
File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 1957, in main
accelerator.save_state(save_path)
File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2706, in save_state
hook(self._models, weights, output_dir)
File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 1500, in save_model_hook
model.save_pretrained(os.path.join(output_dir, "unet"))
File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Discriminator' object has no attribute 'save_pretrained'
wandb: \ 17.403 MB of 17.417 MB uploaded
wandb: Run history:
wandb: d_adv_loss_fake ▆▆▃▄▂▃▄▅▄▄▇▃▅▃█▇▃▆▂▃█▅▄▂▂▂▅▃▂▃▄▇▆▄▃▃▆▁▁▄
wandb: d_adv_loss_real ▆▆▆▄▇▅▅▃▇▅▄▅▄▇▂▂▃▄▇▄▄▄▆▃▃▃▃▄▃▃▅▁▄▃▅▂▂▄█▅
wandb: d_loss_real ▆▆▆▅▇▅▅▃▇▅▄▅▄▇▂▂▃▄▇▄▄▄▆▃▃▃▃▄▃▃▅▁▄▃▅▂▂▄█▅
wandb: d_r1_regularizer ███▆▇▄▆▄▇▆▄▅▃▆▁▁▂▄▆▂▃▄▄▂▄▃▃▅▃▂▆▁▄▃▅▂▁▄▇▄
wandb: d_total_loss ██▅▅▅▄▅▅▆▅▇▄▅▅▇▆▃▇▅▃█▅▅▁▂▂▅▃▂▃▅▅▆▄▄▂▅▁▄▅
wandb: g_adv_loss ▂▂▄▃▅▅▄▄▄▄▂▄▄▅▁▂▄▂▅▆▁▃▆█▆▅▂▆▆▄▄▃▄▆▅▆▃▇▇▅
wandb: g_distill_loss ▄▃▃▂▃▁▅▃▅▁▂▃▃▄▁▁▂▃▂▂▂▁▂▂▁▁▅▃▅▂▃▁▂▂▁▂█▂▁▃
wandb: g_total_loss ▂▂▄▃▅▅▄▄▄▄▂▄▄▅▁▂▄▂▅▆▁▃▆█▆▅▃▆▆▄▄▃▄▆▅▆▃▇▇▅
wandb: lr ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:
wandb: Run summary:
wandb: d_adv_loss_fake 1.29867
wandb: d_adv_loss_real 0.19679
wandb: d_loss_real 0.20999
wandb: d_r1_regularizer 1319.4856
wandb: d_total_loss 1.50865
wandb: g_adv_loss -0.22699
wandb: g_distill_loss 0.00964
wandb: g_total_loss -0.20288
wandb: lr 0.0001
wandb:
wandb: 🚀 View run unique-armadillo-16 at: https://wandb.ai/metal/text2image-fine-tune/runs/0tks0q98
wandb: Synced 5 W&B file(s), 32 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240110_154811-0tks0q98/logs
Traceback (most recent call last):
File "/home/metal/dpo_test/venv/bin/accelerate", line 8, in <module>
sys.exit(main())
File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
args.func(args)
File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1017, in launch_command
simple_launcher(args)
File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 637, in simple_launcher
raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/metal/dpo_test/venv/bin/python3', 'train_add_distill_sd_wds.py', '--pretrained_teacher_model=../sonicDiffusionV4', '--train_shards_path_or_url=pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true', '--output_dir=add_model', '--max_train_steps=1000', '--max_train_samples=4000000', '--dataloader_num_workers=8', '--train_batch_size=2', '--allow_tf32', '--mixed_precision=fp16', '--report_to=wandb', '--gradient_checkpointing', '--use_8bit_adam', '--gradient_accumulation_steps=8', '--allow_nonzero_terminal_snr']' returned non-zero exit status 1.
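For what it's worth, the AttributeError above suggests the save_model_hook calls save_pretrained() on every model accelerate hands to it, including the GAN discriminator, which is a plain nn.Module. A possible fix, sketched below with the names taken from the traceback (Discriminator, accelerator) and otherwise illustrative, is to branch on the model type:

```python
import os
import torch
from diffusers import UNet2DConditionModel

def save_model_hook(models, weights, output_dir):
    # Sketch only: `accelerator` and `Discriminator` are assumed to be defined in the
    # training script, as in the traceback above.
    if accelerator.is_main_process:
        for model in models:
            unwrapped = accelerator.unwrap_model(model)
            if isinstance(unwrapped, UNet2DConditionModel):
                unwrapped.save_pretrained(os.path.join(output_dir, "unet"))
            elif isinstance(unwrapped, Discriminator):
                # Plain nn.Module without save_pretrained(): save the state dict directly.
                torch.save(unwrapped.state_dict(), os.path.join(output_dir, "discriminator.pt"))
            # Pop the corresponding weights so accelerate does not also save this model itself.
            if weights:
                weights.pop()
```

(The matching load_model_hook would need the same special-casing so that accelerator.load_state can restore the discriminator from the checkpoint.)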
How far away is this PR from being merged?
Hi @SteamedGit, the ADD implementation is nominally complete, but I have not yet been able to test whether the script can distill good models (e.g. for SD v1.5).
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Not stale.
@sayakpaul @dg845 Great job! Can someone please confirm whether the effectiveness of this PR has been verified?
Regarding computing the SDS loss, I suggest taking a look at https://arxiv.org/abs/2306.04619, which tends to produce a better target.
@cjt222 sorry, I haven't been able to finish testing it yet. Will hopefully find more time to work on it soon 😅.