
[WIP] Add Adversarial Diffusion Distillation (ADD) Script

Open dg845 opened this issue 2 years ago • 14 comments

What does this PR do?

This PR adds an example script for adversarial diffusion distillation (ADD) (paper, code), a distillation + adversarial training method used to distill SD/SD-XL Turbo.

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline?
  • [x] Did you read our philosophy doc (important for complex PRs)?
  • [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@patrickvonplaten @sayakpaul @patil-suraj

dg845 avatar Dec 23 '23 23:12 dg845

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

There's currently an issue in the script: if student_timesteps takes on the value noise_scheduler.config.num_train_timesteps - 1 (= 999), then alpha_schedule[noise_scheduler.config.num_train_timesteps - 1] == 0 because noise_scheduler has rescale_betas_zero_snr == True. When we convert the student unet's noise prediction to a predicted original sample in Denoiser, we end up dividing by alpha_schedule[999] == 0, so nans/infs appear in student_x_0, which ultimately causes the gradients to become nans.

(Note that if rescale_betas_zero_snr == False, alpha_schedule[noise_scheduler.config.num_train_timesteps - 1] $\approx 0.0683$, which is far away from 0, as noted by the paper that introduced rescaling to zero terminal SNR. Also note that the same thing could theoretically happen with teacher_timesteps, but it's much less likely since we sample over the whole range of training timesteps $\{0, 1, ..., 999\}$ rather than just 4 timesteps, one of which is $999$.)
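
For concreteness, here is a minimal sketch of the conversion where this blows up. The variable names follow the convention of the existing LCM distillation scripts (alpha_schedule = sqrt(alphas_cumprod), sigma_schedule = sqrt(1 - alphas_cumprod)); the actual names in this PR's script may differ:

def predicted_original_sample(noise_pred, noisy_latents, timesteps, alpha_schedule, sigma_schedule):
    # Standard epsilon -> x_0 conversion: x_0 = (x_t - sigma_t * eps) / alpha_t
    alphas = alpha_schedule[timesteps].view(-1, 1, 1, 1)
    sigmas = sigma_schedule[timesteps].view(-1, 1, 1, 1)
    # With rescale_betas_zero_snr == True, alphas_cumprod[-1] == 0, so alphas == 0 whenever a
    # sampled timestep equals num_train_timesteps - 1, and this division produces infs/nans.
    return (noisy_latents - sigmas * noise_pred) / alphas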

The ADD paper indicates that they both use $999$ as a student timestep and zero terminal SNR:

[Screenshot: add_noise_schedule — excerpt from the ADD paper describing the noise schedule and student timesteps]

I'm not sure how they handle this divide-by-zero issue; perhaps they divide by alphas + eps for some small eps > 0 when converting from a noise prediction to an original-sample prediction? Thoughts @patil-suraj @sayakpaul?

[Edit: I see that EulerDiscreteScheduler handles this by setting the last alpha to a small positive value:

https://github.com/huggingface/diffusers/blob/79c380bc80051d8f82a38c3a7df6f8f4efd1633d/src/diffusers/schedulers/scheduling_euler_discrete.py#L214-L217
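
For reference, the gist of those lines (paraphrased; see the link for the exact code):

if rescale_betas_zero_snr:
    # Close to 0 without being 0 so the first sigma is not inf
    # (the smallest positive FP16 subnormal, 2**-24, works well here)
    self.alphas_cumprod[-1] = 2**-24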

So that's another option.]

dg845 avatar Jan 02 '24 19:01 dg845

If I set alphas_schedule[-1] == 2**-24, the student predicted original sample student_x_0 no longer has any nans, but whenever student_timesteps is noise_scheduler.config.num_train_timesteps - 1 the predicted x_0 has large positive and negative values (as expected). When we decode using the VAE, the decoded output student_gen_image then has nans at those locations (at least for the small random test VAE at dg845/tiny-random-stable-diffusion; I'm not sure what happens if a full pretrained SD/SD-XL checkpoint is used). This is true even if we use a small positive value that's several orders of magnitude larger, e.g. alphas_schedule[-1] == 2**-12.

dg845 avatar Jan 02 '24 22:01 dg845

Thanks for all the details. Probably a very dumb ask but have you seen any widely different training dynamics without enabling zero-terminal SNR? If not, we could definitely try that out.

But if I were to try things out in order of priority:

  • I would try out Euler
  • Replicate these changes with the current sampler, though it looks like you have already tried that
  • Follow up with the ADD authors about this phenomenon

sayakpaul avatar Jan 03 '24 05:01 sayakpaul

Probably a very dumb ask but have you seen any widely different training dynamics without enabling zero-terminal SNR?

I've only tested on very small toy checkpoints like dg845/tiny-random-stable-diffusion due to a lack of computational resources, so I'm not sure whether the training dynamics are different. My guess is that without enforcing zero terminal SNR the distilled checkpoint might have the medium-brightness sample issue described in https://arxiv.org/pdf/2305.08891.pdf.

(BTW, would it be possible to get some GPU resources to help test the script?😅)

dg845 avatar Jan 05 '24 03:01 dg845

(BTW, would it be possible to get some GPU resources to help test the script?😅)

Checking that internally. Will keep you posted.

sayakpaul avatar Jan 05 '24 03:01 sayakpaul

I will start testing this and will keep you posted.

patil-suraj avatar Jan 05 '24 09:01 patil-suraj

Excuse my ignorance, but how do I run it? I get this error trying to run either the LoRA or the SD 1.5 version:

accelerate launch train_add_distill_lora_sd_wds.py --pretrained_teacher_model C:\Users\Pablo\Documents\mobians_api\sonicDiffusionV4 --train_shards_path_or_url laion/conceptual-captions-12m-webdataset

Traceback (most recent call last):
  File "C:\Users\Pablo\Downloads\diffusers\examples\add\train_add_distill_lora_sd_wds.py", line 2114, in <module>
    main(args)
  File "C:\Users\Pablo\Downloads\diffusers\examples\add\train_add_distill_lora_sd_wds.py", line 1693, in main
    dataset = SDText2ImageDataset(
  File "C:\Users\Pablo\Downloads\diffusers\examples\add\train_add_distill_lora_sd_wds.py", line 238, in __init__
    num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers))  # per dataloader worker
TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'

Metal079 avatar Jan 08 '24 23:01 Metal079

I think there are a few additional arguments that need to be explicitly supplied for the scripts not to raise an error. Something close to the minimal set of arguments is:

accelerate launch examples/add/train_add_distill_lora_sd_wds.py \
    --pretrained_teacher_model="<teacher_model>" \
    --train_shards_path_or_url="<dataset>" \
    --output_dir="<output_dir>" \
    --max_train_steps=1 \
    --max_train_samples=20 \
    --dataloader_num_workers=8

assuming the other default values work (for example, --train_batch_size is not so big that it leads to an OOM error).
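
(For context, the TypeError in the traceback above presumably comes from num_train_examples being None when --max_train_samples is omitted. A minimal, self-contained sketch of the computation from line 238 of the traceback, with a hypothetical helper name and example values:)

import math

def worker_batches(num_train_examples, global_batch_size, num_workers):
    # If --max_train_samples is not passed, num_train_examples is None here and the division
    # raises: TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'
    return math.ceil(num_train_examples / (global_batch_size * num_workers))

print(worker_batches(20, global_batch_size=2, num_workers=8))  # 2 batches per dataloader worker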

Note that the scripts are a work in progress and there's no guarantee that they work currently.

dg845 avatar Jan 10 '24 04:01 dg845

Got it running, but I ran into a bug when saving. The validation images also looked like random noise.

Steps:  40%|▍| 400/1000 [35:56<53:16,  5.33s/it, d_total_loss=1.46, g_adv_loss=-.086, g_distill_loss=0.123, g_total_loss
01/10/2024 16:24:08 - INFO - __main__ - Running validation...
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of ../sonicDiffusionV4.
Loaded text_encoder as CLIPTextModel from `text_encoder` subfolder of ../sonicDiffusionV4.
Loaded scheduler as PNDMScheduler from `scheduler` subfolder of ../sonicDiffusionV4.
Loading pipeline components...: 100%|█████████████████████████████████████████████████████| 5/5 [00:00<00:00,  9.88it/s]
You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
Steps:  50%|▌| 500/1000 [44:53<43:40,  5.24s/it, d_total_loss=1.51, g_adv_loss=-.227, g_distill_loss=0.00964, g_total_lo
01/10/2024 16:33:06 - INFO - accelerate.accelerator - Saving current state to add_model/checkpoint-500
Configuration saved in add_model/checkpoint-500/unet/config.json
Model weights saved in add_model/checkpoint-500/unet/diffusion_pytorch_model.safetensors
Traceback (most recent call last):
  File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 2010, in <module>
    main(args)
  File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 1957, in main
    accelerator.save_state(save_path)
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2706, in save_state
    hook(self._models, weights, output_dir)
  File "/home/metal/dpo_test/add/train_add_distill_sd_wds.py", line 1500, in save_model_hook
    model.save_pretrained(os.path.join(output_dir, "unet"))
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1695, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'Discriminator' object has no attribute 'save_pretrained'
wandb: \ 17.403 MB of 17.417 MB uploaded
wandb: Run history:
wandb:  d_adv_loss_fake ▆▆▃▄▂▃▄▅▄▄▇▃▅▃█▇▃▆▂▃█▅▄▂▂▂▅▃▂▃▄▇▆▄▃▃▆▁▁▄
wandb:  d_adv_loss_real ▆▆▆▄▇▅▅▃▇▅▄▅▄▇▂▂▃▄▇▄▄▄▆▃▃▃▃▄▃▃▅▁▄▃▅▂▂▄█▅
wandb:      d_loss_real ▆▆▆▅▇▅▅▃▇▅▄▅▄▇▂▂▃▄▇▄▄▄▆▃▃▃▃▄▃▃▅▁▄▃▅▂▂▄█▅
wandb: d_r1_regularizer ███▆▇▄▆▄▇▆▄▅▃▆▁▁▂▄▆▂▃▄▄▂▄▃▃▅▃▂▆▁▄▃▅▂▁▄▇▄
wandb:     d_total_loss ██▅▅▅▄▅▅▆▅▇▄▅▅▇▆▃▇▅▃█▅▅▁▂▂▅▃▂▃▅▅▆▄▄▂▅▁▄▅
wandb:       g_adv_loss ▂▂▄▃▅▅▄▄▄▄▂▄▄▅▁▂▄▂▅▆▁▃▆█▆▅▂▆▆▄▄▃▄▆▅▆▃▇▇▅
wandb:   g_distill_loss ▄▃▃▂▃▁▅▃▅▁▂▃▃▄▁▁▂▃▂▂▂▁▂▂▁▁▅▃▅▂▃▁▂▂▁▂█▂▁▃
wandb:     g_total_loss ▂▂▄▃▅▅▄▄▄▄▂▄▄▅▁▂▄▂▅▆▁▃▆█▆▅▃▆▆▄▄▃▄▆▅▆▃▇▇▅
wandb:               lr ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:
wandb: Run summary:
wandb:  d_adv_loss_fake 1.29867
wandb:  d_adv_loss_real 0.19679
wandb:      d_loss_real 0.20999
wandb: d_r1_regularizer 1319.4856
wandb:     d_total_loss 1.50865
wandb:       g_adv_loss -0.22699
wandb:   g_distill_loss 0.00964
wandb:     g_total_loss -0.20288
wandb:               lr 0.0001
wandb:
wandb: 🚀 View run unique-armadillo-16 at: https://wandb.ai/metal/text2image-fine-tune/runs/0tks0q98
wandb: Synced 5 W&B file(s), 32 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-20240110_154811-0tks0q98/logs
Traceback (most recent call last):
  File "/home/metal/dpo_test/venv/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1017, in launch_command
    simple_launcher(args)
  File "/home/metal/dpo_test/venv/lib/python3.10/site-packages/accelerate/commands/launch.py", line 637, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/home/metal/dpo_test/venv/bin/python3', 'train_add_distill_sd_wds.py', '--pretrained_teacher_model=../sonicDiffusionV4', '--train_shards_path_or_url=pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true', '--output_dir=add_model', '--max_train_steps=1000', '--max_train_samples=4000000', '--dataloader_num_workers=8', '--train_batch_size=2', '--allow_tf32', '--mixed_precision=fp16', '--report_to=wandb', '--gradient_checkpointing', '--use_8bit_adam', '--gradient_accumulation_steps=8', '--allow_nonzero_terminal_snr']' returned non-zero exit status 1.
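
(A note on the AttributeError during checkpointing: the traceback shows save_model_hook calling model.save_pretrained(...) on the Discriminator, which is a plain nn.Module without that method. A guarded hook along these lines would avoid the crash; this is only a sketch based on the traceback, not the PR's actual code, and the discriminator filename is made up:)

import os
import torch

def save_model_hook(models, weights, output_dir):
    for model in models:
        if hasattr(model, "save_pretrained"):
            # diffusers models (e.g. the student UNet) know how to serialize themselves
            model.save_pretrained(os.path.join(output_dir, "unet"))
        else:
            # plain nn.Modules such as the Discriminator need torch.save instead
            torch.save(model.state_dict(), os.path.join(output_dir, "discriminator.pt"))
        # pop the corresponding weights so accelerate does not save them a second time
        if weights:
            weights.pop()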


Metal079 avatar Jan 10 '24 23:01 Metal079

How far away is this PR from being merged?

SteamedGit avatar Feb 03 '24 10:02 SteamedGit

Hi @SteamedGit, the ADD implementation is nominally complete, but I have not yet been able to test whether the script can distill good models (e.g. for SD v1.5).

dg845 avatar Feb 05 '24 00:02 dg845

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Mar 05 '24 15:03 github-actions[bot]

Not stale.

sayakpaul avatar Mar 05 '24 18:03 sayakpaul

@sayakpaul @dg845 Great job! Can someone please confirm whether the effectiveness of this PR has been verified?

cjt222 avatar Mar 13 '24 08:03 cjt222

Regarding computing the SDS loss, I suggest taking a look at https://arxiv.org/abs/2306.04619, which tends to produce a better target.

erliding avatar Mar 13 '24 10:03 erliding

@cjt222 sorry, I haven't been able to finish testing it yet. Will hopefully find more time to work on it soon 😅.

dg845 avatar Mar 18 '24 00:03 dg845

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 11 '24 15:04 github-actions[bot]