[Bug] It seems that there is a logic error in the Dreambooth generate_class_prior_images function

Open wangqiang9 opened this issue 2 years ago • 1 comments
Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

main branch https://github.com/open-mmlab/mmagic

Environment

Environment is correct.

Reproduces the problem - code sample

In the Dreambooth class, prior_loss_weight defaults to 0, and with that setting the program runs. However, if it is changed to 1 or any other non-zero value, the following error is reported:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/tools/train.py:114 in <module>                 │
│                                                                                                  │
│   111                                                                                            │
│   112                                                                                            │
│   113 if __name__ == '__main__':                                                                 │
│ ❱ 114 │   main()                                                                                 │
│   115                                                                                            │
│                                                                                                  │
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/tools/train.py:107 in main                     │
│                                                                                                  │
│   104 │   print_colored_log(f'Log directory: {runner._log_dir}')                                 │
│   105 │                                                                                          │
│   106 │   # start training                                                                       │
│ ❱ 107 │   runner.train()                                                                         │
│   108 │                                                                                          │
│   109 │   print_colored_log(f'Log saved under {runner._log_dir}')                                │
│   110 │   print_colored_log(f'Checkpoint saved under {cfg.work_dir}')                            │
│                                                                                                  │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/runner/runner.py:1721 in train       │
│                                                                                                  │
│   1718 │   │   # This must be called **AFTER** model has been wrapped.                           │
│   1719 │   │   self._maybe_compile('train_step')                                                 │
│   1720 │   │                                                                                     │
│ ❱ 1721 │   │   model = self.train_loop.run()  # type: ignore                                     │
│   1722 │   │   self.call_hook('after_run')                                                       │
│   1723 │   │   return model                                                                      │
│   1724                                                                                           │
│                                                                                                  │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/runner/loops.py:278 in run           │
│                                                                                                  │
│   275 │   │   │   self.runner.model.train()                                                      │
│   276 │   │   │                                                                                  │
│   277 │   │   │   data_batch = next(self.dataloader_iterator)                                    │
│ ❱ 278 │   │   │   self.run_iter(data_batch)                                                      │
│   279 │   │   │                                                                                  │
│   280 │   │   │   self._decide_current_val_interval()                                            │
│   281 │   │   │   if (self.runner.val_loop is not None                                           │
│                                                                                                  │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/runner/loops.py:302 in run_iter      │
│                                                                                                  │
│   299 │   │   # synchronization during gradient accumulation process.                            │
│   300 │   │   # outputs should be a dict of loss.                                                │
│   301 │   │   outputs = self.runner.model.train_step(                                            │
│ ❱ 302 │   │   │   data_batch, optim_wrapper=self.runner.optim_wrapper)                           │
│   303 │   │                                                                                      │
│   304 │   │   self.runner.call_hook(                                                             │
│   305 │   │   │   'after_train_iter',                                                            │
│                                                                                                  │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/model/wrappers/seperate_distributed. │
│ py:102 in train_step                                                                             │
│                                                                                                  │
│    99 │   │   Returns:                                                                           │
│   100 │   │   │   Dict[str, torch.Tensor]: A dict of tensor for logging.                         │
│   101 │   │   """                                                                                │
│ ❱ 102 │   │   return self.module.train_step(data, optim_wrapper)                                 │
│   103 │                                                                                          │
│   104 │   def val_step(self, data: Union[dict, tuple, list]) -> list:                            │
│   105 │   │   """Gets the prediction of module during validation process.                        │
│                                                                                                  │
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/mmagic/models/editors/dreambooth/dreambooth.py │
│ :253 in train_step                                                                               │
│                                                                                                  │
│   250 │   │   │   #print(f"=======self.prior_loss_weight: {self.prior_loss_weight}")             │
│   251 │   │   │   if self.prior_loss_weight != 0:                                                │
│   252 │   │   │   │   # image and prompt for prior preservation                                  │
│ ❱ 253 │   │   │   │   self.generate_class_prior_images(num_batches=num_batches)                  │
│   254 │   │   │   │   class_images_used = []                                                     │
│   255 │   │   │   │   for _ in range(num_batches):                                               │
│   256 │   │   │   │   │   idx = random.randint(0, len(self.class_images) - 1)                    │
│                                                                                                  │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/torch/autograd/grad_mode.py:27 in             │
│ decorate_context                                                                                 │
│                                                                                                  │
│    24 │   │   @functools.wraps(func)                                                             │
│    25 │   │   def decorate_context(*args, **kwargs):                                             │
│    26 │   │   │   with self.clone():                                                             │
│ ❱  27 │   │   │   │   return func(*args, **kwargs)                                               │
│    28 │   │   return cast(F, decorate_context)                                                   │
│    29 │                                                                                          │
│    30 │   def _wrap_generator(self, func):                                                       │
│                                                                                                  │
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/mmagic/models/editors/dreambooth/dreambooth.py │
│ :124 in generate_class_prior_images                                                              │
│                                                                                                  │
│   121 │   │                                                                                      │
│   122 │   │   print("=============self.class_prior_prompt: ", self.class_prior_prompt)           │
│   123 │   │   assert self.class_prior_prompt is not None, (                                      │
│ ❱ 124 │   │   │   '\'class_prior_prompt\' must be set when \'prior_loss_weight\' is '            │
│   125 │   │   │   'larger than 0.')                                                              │
│   126 │   │   assert self.num_class_images is not None, (                                        │
│   127 │   │   │   '\'num_class_images\' must be set when \'prior_loss_weight\' is '              │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError: 'class_prior_prompt' must be set when 'prior_loss_weight' is larger than 0.

I think there is a logic error in the generate_class_prior_images function: the assertion fires whenever the prior_loss_weight parameter is not equal to 0.

Reproduces the problem - command or script

https://github.com/open-mmlab/mmagic/blob/main/configs/dreambooth/README.md

Reproduces the problem - error message

Same as the traceback above.

Additional information

No response

wangqiang9 avatar May 19 '23 11:05 wangqiang9

@XDUWQ, you must pass a list of class_prior_prompt if you set prior_loss_weight to a value larger than 0.
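
For readers hitting the same assertion, here is a minimal, self-contained sketch of the check that the traceback shows, followed by settings that would satisfy it. It is not the mmagic implementation: `check_prior_settings` and `model_overrides` are hypothetical names, the prompt text and image count are placeholders, and the second error message is worded for illustration because the traceback cuts it off. The reply above mentions a list while this sketch uses a single string, so check the Dreambooth docstring in your mmagic version for the expected type.

```python
from typing import Optional, Sequence, Union

PromptType = Union[str, Sequence[str]]


def check_prior_settings(prior_loss_weight: float,
                         class_prior_prompt: Optional[PromptType],
                         num_class_images: Optional[int]) -> None:
    """Hypothetical stand-in for the guards seen in the traceback."""
    if prior_loss_weight != 0:
        # Message copied from the AssertionError in the issue.
        assert class_prior_prompt is not None, (
            "'class_prior_prompt' must be set when 'prior_loss_weight' is "
            "larger than 0.")
        # Illustrative wording; the original message is truncated in the log.
        assert num_class_images is not None, (
            "'num_class_images' is required (illustrative message)")


# Placeholder values, assumed for illustration only.
model_overrides = dict(
    prior_loss_weight=1,
    class_prior_prompt='a photo of a dog',
    num_class_images=3,
)

# Passes silently because both prior-preservation fields are set.
check_prior_settings(**model_overrides)
```

With prior_loss_weight left at 0, the check is skipped entirely, which matches the report that the default configuration runs without error.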

LeoXing1996 avatar May 29 '23 05:05 LeoXing1996

Close this issue since there has been no response. Please feel free to reopen it if needed.

LeoXing1996 avatar Jul 18 '23 08:07 LeoXing1996