
[WIP] Sample images when checkpointing.

Open LucasSloan opened this issue 1 year ago • 21 comments

I based this on the in-progress sampling in https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py, following the suggestion on https://github.com/huggingface/diffusers/pull/2030 that it was a good example to follow.

Unfortunately, this code doesn't work at present and I'm not sure why. I get the error RuntimeError: Input type (c10::Half) and bias type (float) should be the same, full stack trace:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:757 in    │
│ <module>                                                                                         │
│                                                                                                  │
│   754                                                                                            │
│   755                                                                                            │
│   756 if __name__ == "__main__":                                                                 │
│ ❱ 757 │   main()                                                                                 │
│   758                                                                                            │
│                                                                                                  │
│ /mnt/c/Users/lucas/Development/diffusers/examples/text_to_image/train_text_to_image.py:720 in    │
│ main                                                                                             │
│                                                                                                  │
│   717 │   │   │   │   │   │   │                                                                  │
│   718 │   │   │   │   │   │   │   # run inference                                                │
│   719 │   │   │   │   │   │   │   prompt = [args.validation_prompt]                              │
│ ❱ 720 │   │   │   │   │   │   │   images = pipeline(prompt, num_images_per_prompt=args.num_val   │
│   721 │   │   │   │   │   │   │                                                                  │
│   722 │   │   │   │   │   │   │   for i, image in enumerate(images):                             │
│   723 │   │   │   │   │   │   │   │   image.save(os.path.join(args.output_dir, f"sample-{globa   │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/utils/_contextlib.py:115 in                 │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_sta │
│ ble_diffusion.py:611 in __call__                                                                 │
│                                                                                                  │
│   608 │   │   │   │   latent_model_input = self.scheduler.scale_model_input(latent_model_input   │
│   609 │   │   │   │                                                                              │
│   610 │   │   │   │   # predict the noise residual                                               │
│ ❱ 611 │   │   │   │   noise_pred = self.unet(                                                    │
│   612 │   │   │   │   │   latent_model_input,                                                    │
│   613 │   │   │   │   │   t,                                                                     │
│   614 │   │   │   │   │   encoder_hidden_states=prompt_embeds,                                   │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl     │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py:482 in      │
│ forward                                                                                          │
│                                                                                                  │
│   479 │   │   │   emb = emb + class_emb                                                          │
│   480 │   │                                                                                      │
│   481 │   │   # 2. pre-process                                                                   │
│ ❱ 482 │   │   sample = self.conv_in(sample)                                                      │
│   483 │   │                                                                                      │
│   484 │   │   # 3. down                                                                          │
│   485 │   │   down_block_res_samples = (sample,)                                                 │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/module.py:1488 in _call_impl     │
│                                                                                                  │
│   1485 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1486 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1487 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1488 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1489 │   │   # Do not call functions when jit is used                                          │
│   1490 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1491 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:463 in forward           │
│                                                                                                  │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
│ ❱  463 │   │   return self._conv_forward(input, self.weight, self.bias)                          │
│    464                                                                                           │
│    465 class Conv3d(_ConvNd):                                                                    │
│    466 │   __doc__ = r"""Applies a 3D convolution over an input signal composed of several inpu  │
│                                                                                                  │
│ /home/lucas/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:459 in _conv_forward     │
│                                                                                                  │
│    456 │   │   │   return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=sel  │
│    457 │   │   │   │   │   │   │   weight, bias, self.stride,                                    │
│    458 │   │   │   │   │   │   │   _pair(0), self.dilation, self.groups)                         │
│ ❱  459 │   │   return F.conv2d(input, weight, bias, self.stride,                                 │
│    460 │   │   │   │   │   │   self.padding, self.dilation, self.groups)                         │
│    461 │                                                                                         │
│    462 │   def forward(self, input: Tensor) -> Tensor:                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Input type (c10::Half) and bias type (float) should be the same

I tried to fix it on line 713 by setting torch_dtype=weight_dtype on the StableDiffusionPipeline, but that didn't work.

LucasSloan avatar Jan 29 '23 22:01 LucasSloan

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

+1

Weifeng-Chen avatar Jan 30 '23 07:01 Weifeng-Chen

Hey @LucasSloan,

Thanks a lot for opening the PR. Want to explain a bit what's going on here.

When training / fine-tuning stable diffusion models we noticed the following:

  • Weights should not be cast to float16 if those weights are trained, e.g. the UNet here. Instead such weights should be wrapped into mixed precision training which is done automatically with accelerate: https://huggingface.co/docs/accelerate/v0.15.0/en/package_reference/accelerator#accelerate.Accelerator.mixed_precision
  • Weights can/should be cast to float16 if those weights are used for training, but not trained, e.g. the text encoder here. Those weights can simply be cast to float16 with .to(torch.float16).

Now this works fine in training because accelerate takes care of everything. However during inference, we run into the problem that the UNet is in fp32 while the text encoder is in fp16. We could cast the unet into float16 which would work just fine for inference, but then we're breaking the model for the training after inference (remember trainable weights should stay in fp32). Thus, the solution here is to use autocast to automatically cast down the unet if necessary as shown in the PR review.

Note: We never recommend using autocast for pure inference but only for such special training cases. Does this make sense? Could you try whether it works with autocast?
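To illustrate the autocast suggestion, here is a minimal sketch. It assumes PyTorch is installed and, so that it runs without a GPU, it uses CPU autocast with bfloat16 and a small nn.Linear as a stand-in for the UNet; the training script in this PR would use CUDA and fp16 instead. The names layer and low_precision_input are illustrative and not taken from the script.

```python
import torch
from torch import nn

# Stand-in for the trainable UNet: its weights stay in full precision (fp32).
layer = nn.Linear(4, 2)

# Stand-in for the output of a frozen module that was cast to low precision.
low_precision_input = torch.randn(3, 4, dtype=torch.bfloat16)

# Calling the fp32 layer directly on low-precision input is expected to fail
# with a dtype-mismatch RuntimeError, similar to the one in the traceback.
try:
    layer(low_precision_input)
    direct_call_failed = False
except RuntimeError:
    direct_call_failed = True

# Under autocast the op runs in low precision without modifying the weights.
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = layer(low_precision_input)

print(direct_call_failed)
print(out.dtype)           # dtype autocast chose for this op
print(layer.weight.dtype)  # weights are untouched, so training can continue
```

The point of the design is that the cast happens per operation inside the context manager, so nothing has to be converted back before the next training step.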

Also could you maybe try to add the wandb and tensorboard loggers here as well: https://github.com/huggingface/diffusers/blob/7d96b38b70407ec69816e71cdcfc4e0b41e26768/examples/dreambooth/train_dreambooth_lora.py#L948 ?

Also cc @patil-suraj

patrickvonplaten avatar Jan 31 '23 10:01 patrickvonplaten

I met a similar issue and solved it by disabling the safety checker (it wasn't used during training, so its dtype may not have been converted). Casting latents, noise, and noisy_latents to self.unet.dtype may help as well (when I trained with Lightning, the variables defined in the training loop weren't converted...).

Weifeng-Chen avatar Jan 31 '23 12:01 Weifeng-Chen

Related issue and discussion

https://github.com/huggingface/diffusers/issues/2163#issuecomment-1410035940 https://github.com/huggingface/diffusers/pull/2173#pullrequestreview-1277057469

patil-suraj avatar Jan 31 '23 12:01 patil-suraj

Using torch.autocast() makes sense to me, but doesn't seem to resolve the issue. The other PR implementing similar functionality seems to be getting around it by not using fp16 weights at all (by reloading the full weights). Any other thoughts?

LucasSloan avatar Jan 31 '23 17:01 LucasSloan

Fixed it by not unwrapping the unet.

LucasSloan avatar Jan 31 '23 20:01 LucasSloan

Added wandb and tensorboard integration.

LucasSloan avatar Jan 31 '23 21:01 LucasSloan

So, I have been testing it since this morning. Here are my findings:

  • I could run the intermediate validation inference successfully with FP16, autocasting, and appropriate EMA updates.
  • However, during the final inference run which we usually do after pushing the pipeline files to the Hub, it is still failing.
Traceback (most recent call last):
  File "train_text_to_image.py", line 982, in <module>
    main()
  File "train_text_to_image.py", line 953, in main
    pipeline(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 611, in __call__
    noise_pred = self.unet(
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 482, in forward
    sample = self.conv_in(sample)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/py38/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (c10::Half) and bias type (float) should be the same

I think we can potentially get around the issue if we cast the UNet and the Safety Checker modules to weight_dtype after initializing the pipeline:

if accelerator.is_main_process:
    unet = accelerator.unwrap_model(unet)
    if args.use_ema:
        ema_unet.copy_to(unet.parameters())

    pipeline = StableDiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        revision=args.revision,
    )

    ...

    # before running inference
    pipeline.unet = unet.to(weight_dtype)
    pipeline.safety_checker = pipeline.safety_checker.to(weight_dtype)

Here's my gist that contains the modified train_text_to_image.py script, ema.py script (thanks to @patil-suraj for the suggestions here), and the execution instructions.

Let me know if anything is unclear.

sayakpaul avatar Feb 02 '23 07:02 sayakpaul

@sayakpaul Thanks for testing this! I think we should wrap the pipe call in autocast here as well; we cannot explicitly cast the model here, as we always save models in full precision.

patil-suraj avatar Feb 02 '23 12:02 patil-suraj

@patil-suraj here's what I did:

  • While saving the final pipeline, I didn't use the text_encoder and vae to avoid the mismatch issues:
pipeline = StableDiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=unet,
    revision=args.revision,
)
pipeline.save_pretrained(args.output_dir)
  • And then before running the final inference, I move the pipeline to accelerator.device.

With these changes, things seem to work.

Would you like to test it with the gist (from https://github.com/huggingface/diffusers/pull/2157#issuecomment-1413279450) and the changes above?

Also, would you like me to open a PR adding the store() and restore() methods to EMAModel?
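For context, the store()/restore() idea can be sketched with a toy class. This is a pure-Python illustration over a list of floats, not the diffusers EMAModel implementation; the class name TinyEMA, the decay value, and the variable names are all hypothetical.

```python
class TinyEMA:
    """Toy exponential moving average over a list of float parameters."""

    def __init__(self, params, decay=0.9):
        self.decay = decay
        self.shadow = list(params)  # averaged copy of the parameters
        self._backup = None         # filled by store()

    def step(self, params):
        # Move each averaged value a little towards the current parameter.
        self.shadow = [
            self.decay * s + (1.0 - self.decay) * p
            for s, p in zip(self.shadow, params)
        ]

    def store(self, params):
        # Remember the current training weights before overwriting them.
        self._backup = list(params)

    def copy_to(self, params):
        # Overwrite the parameters in place with the averaged values.
        params[:] = self.shadow

    def restore(self, params):
        # Put the training weights back after validation.
        params[:] = self._backup
        self._backup = None


weights = [1.0, 2.0]
ema = TinyEMA(weights)
weights[:] = [3.0, 4.0]  # pretend an optimizer step changed the weights
ema.step(weights)

ema.store(weights)                # 1. save the training weights
ema.copy_to(weights)              # 2. validate with the averaged weights
ema_weights_used = list(weights)
ema.restore(weights)              # 3. resume training from the saved weights
```

Without the store/restore pair, copying the averaged weights in for validation would silently replace the weights being trained.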

sayakpaul avatar Feb 02 '23 15:02 sayakpaul

I've been testing this change, and it seems like it doesn't actually do the fine tuning. Possibly creating the pipeline overwrites the weights, losing the training progress? I'll do some more testing to confirm, but if someone else could try (and maybe think about why it would happen), that'd be great.

LucasSloan avatar Feb 02 '23 17:02 LucasSloan

Another feature I'd like to have is the ability to provide multiple prompts. However, when I added action="append" to the --validation_prompt argument, I get an error on this line:

accelerator.init_trackers("text2image-fine-tune", config=vars(args))

Where tensorboard doesn't like the fact that the --validation_prompt argument is a list instead of one of the basic types it supports. Does anyone have a suggestion for fixing that?

LucasSloan avatar Feb 02 '23 17:02 LucasSloan

I've been testing this change, and it seems like it doesn't actually do the fine tuning. Possibly creating the pipeline overwrites the weights, losing the training progress? I'll do some more testing to confirm, but if someone else could try (and maybe think about why it would happen), that'd be great.

Could you try with the train_text_to_image.py script mentioned in this gist (as mentioned here)? Also, take note of the changes suggested in https://github.com/huggingface/diffusers/pull/2157#issuecomment-1413918300.

Another feature I'd like to have is the ability to provide multiple prompts. However, when I added action="append" to the --validation_prompt argument, I get an error on this line:

accelerator.init_trackers("text2image-fine-tune", config=vars(args))

Where tensorboard doesn't like the fact that the --validation_prompt argument is a list instead of one of the basic types it supports. Does anyone have a suggestion for fixing that?

Our example scripts are meant to be as simple as possible. So, I would deprioritize this feature for the time being :) But to support it in the tracker, you could maybe try to use nargs? This post has some good reference examples.
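A rough sketch of the nargs idea, using only the standard library: the prompts are collected into a list, and any list values are joined into strings before the config is handed to a tracker. The helper tracker_safe_config and the " | " separator are hypothetical choices for illustration, not part of the script or of accelerate.

```python
import argparse

parser = argparse.ArgumentParser()
# nargs="+" collects one or more values into a list.
parser.add_argument("--validation_prompt", type=str, nargs="+", default=None)
parser.add_argument("--num_validation_images", type=int, default=4)

# Parse an explicit argument list so the sketch runs without a command line.
args = parser.parse_args(["--validation_prompt", "cute dragon creature", "Yoda"])


def tracker_safe_config(namespace):
    """Return the arguments as a dict with list values joined into strings."""
    return {
        key: " | ".join(str(item) for item in value)
        if isinstance(value, (list, tuple))
        else value
        for key, value in vars(namespace).items()
    }


tracker_config = tracker_safe_config(args)
print(tracker_config)
```

The resulting dict contains only basic types, so it could be passed as config= in place of vars(args) in the init_trackers call quoted above.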

sayakpaul avatar Feb 02 '23 17:02 sayakpaul

@sayakpaul if you could add the methods to EMAModel, that'd be great.

LucasSloan avatar Feb 03 '23 00:02 LucasSloan

Figured out why the model wasn't training - I was using the --enable_xformers_memory_efficient_attention flag. Is that known not to work, or do I have an issue with my setup?

LucasSloan avatar Feb 03 '23 03:02 LucasSloan

Tried re-installing xformers, and instead of no training, the safety_checker tripped (training on the Pokemon dataset, validation prompt "Yoda"). Tried again, disabling the safety_checker, and I got black images anyway, along with the error message:

/home/lucas/.local/lib/python3.8/site-packages/diffusers/pipelines/pipeline_utils.py:813: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")

LucasSloan avatar Feb 03 '23 17:02 LucasSloan

Thanks a lot @LucasSloan! Would it be okay if I go into your PR and add the necessary changes for the EMAModel? The rest of the changes are looking good!

patil-suraj avatar Feb 08 '23 08:02 patil-suraj

@LucasSloan we recently merged https://github.com/huggingface/diffusers/pull/2302. Should be good to test your changes with EMA now.

sayakpaul avatar Feb 16 '23 14:02 sayakpaul

Added EMA support.

Can someone look into those test failures? I wouldn't expect changes to this file to do anything... Are they known issues?

LucasSloan avatar Feb 18 '23 21:02 LucasSloan

Can someone look into those test failures? I wouldn't expect changes to this file to do anything... Are they known issues?

Thanks a lot for the changes. Yeah the failing tests are unrelated.

sayakpaul avatar Feb 20 '23 07:02 sayakpaul

Rebased and the tests fixed themselves.

LucasSloan avatar Mar 03 '23 17:03 LucasSloan

@sayakpaul do you think this PR still makes sense? Should we try to get it in?

I think so, yes! I am going to review and test the PR tomorrow and comment accordingly.

sayakpaul avatar Mar 13 '23 06:03 sayakpaul

@LucasSloan I tried testing the code today, and with the following command:

accelerate launch --mixed_precision="fp16"  examples/text_to_image/train_text_to_image.py   \
	--pretrained_model_name_or_path=$MODEL_NAME   --dataset_name=$DATASET_NAME   \
	--use_ema   \
	--resolution=512 --center_crop --random_flip   \
	--train_batch_size=1   \
	--gradient_accumulation_steps=4   --gradient_checkpointing   \
	--max_train_steps=20 --max_train_samples=5 \
	--enable_xformers_memory_efficient_attention \
	--learning_rate=1e-05   --max_grad_norm=1   --lr_scheduler="constant" --lr_warmup_steps=0   \
	--mixed_precision="fp16" \
	--validation_prompt="cute dragon creature" --num_validation_images=3 --validation_steps=1  \
	--seed=666  \
	--output_dir="sd-pokemon-model"

there's no validation inference being done. Is there something I am missing out on?

sayakpaul avatar Mar 14 '23 08:03 sayakpaul

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] avatar Apr 07 '23 15:04 github-actions[bot]