[Bug] It seems that there is a logic error in the Dreambooth generate_class_prior_images function
Prerequisite
- [X] I have searched Issues and Discussions but cannot get the expected help.
- [X] I have read the FAQ documentation but cannot get the expected help.
- [X] The bug has not been fixed in the latest version (main) or latest version (0.x).
Task
I'm using the official example scripts/configs for the officially supported tasks/models/datasets.
Branch
main branch https://github.com/open-mmlab/mmagic
Environment
Environment is correct.
Reproduces the problem - code sample
In the `DreamBooth` class, `prior_loss_weight` defaults to 0, and the program works in that case.
However, if it is changed to 1 or any other non-zero value, the following error is raised:
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/tools/train.py:114 in <module> │
│ │
│ 111 │
│ 112 │
│ 113 if __name__ == '__main__': │
│ ❱ 114 │ main() │
│ 115 │
│ │
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/tools/train.py:107 in main │
│ │
│ 104 │ print_colored_log(f'Log directory: {runner._log_dir}') │
│ 105 │ │
│ 106 │ # start training │
│ ❱ 107 │ runner.train() │
│ 108 │ │
│ 109 │ print_colored_log(f'Log saved under {runner._log_dir}') │
│ 110 │ print_colored_log(f'Checkpoint saved under {cfg.work_dir}') │
│ │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/runner/runner.py:1721 in train │
│ │
│ 1718 │ │ # This must be called **AFTER** model has been wrapped. │
│ 1719 │ │ self._maybe_compile('train_step') │
│ 1720 │ │ │
│ ❱ 1721 │ │ model = self.train_loop.run() # type: ignore │
│ 1722 │ │ self.call_hook('after_run') │
│ 1723 │ │ return model │
│ 1724 │
│ │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/runner/loops.py:278 in run │
│ │
│ 275 │ │ │ self.runner.model.train() │
│ 276 │ │ │ │
│ 277 │ │ │ data_batch = next(self.dataloader_iterator) │
│ ❱ 278 │ │ │ self.run_iter(data_batch) │
│ 279 │ │ │ │
│ 280 │ │ │ self._decide_current_val_interval() │
│ 281 │ │ │ if (self.runner.val_loop is not None │
│ │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/runner/loops.py:302 in run_iter │
│ │
│ 299 │ │ # synchronization during gradient accumulation process. │
│ 300 │ │ # outputs should be a dict of loss. │
│ 301 │ │ outputs = self.runner.model.train_step( │
│ ❱ 302 │ │ │ data_batch, optim_wrapper=self.runner.optim_wrapper) │
│ 303 │ │ │
│ 304 │ │ self.runner.call_hook( │
│ 305 │ │ │ 'after_train_iter', │
│ │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/mmengine/model/wrappers/seperate_distributed. │
│ py:102 in train_step │
│ │
│ 99 │ │ Returns: │
│ 100 │ │ │ Dict[str, torch.Tensor]: A dict of tensor for logging. │
│ 101 │ │ """ │
│ ❱ 102 │ │ return self.module.train_step(data, optim_wrapper) │
│ 103 │ │
│ 104 │ def val_step(self, data: Union[dict, tuple, list]) -> list: │
│ 105 │ │ """Gets the prediction of module during validation process. │
│ │
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/mmagic/models/editors/dreambooth/dreambooth.py │
│ :253 in train_step │
│ │
│ 250 │ │ │ #print(f"=======self.prior_loss_weight: {self.prior_loss_weight}") │
│ 251 │ │ │ if self.prior_loss_weight != 0: │
│ 252 │ │ │ │ # image and prompt for prior preservation │
│ ❱ 253 │ │ │ │ self.generate_class_prior_images(num_batches=num_batches) │
│ 254 │ │ │ │ class_images_used = [] │
│ 255 │ │ │ │ for _ in range(num_batches): │
│ 256 │ │ │ │ │ idx = random.randint(0, len(self.class_images) - 1) │
│ │
│ /opt/conda/envs/mmagic/lib/python3.7/site-packages/torch/autograd/grad_mode.py:27 in │
│ decorate_context │
│ │
│ 24 │ │ @functools.wraps(func) │
│ 25 │ │ def decorate_context(*args, **kwargs): │
│ 26 │ │ │ with self.clone(): │
│ ❱ 27 │ │ │ │ return func(*args, **kwargs) │
│ 28 │ │ return cast(F, decorate_context) │
│ 29 │ │
│ 30 │ def _wrap_generator(self, func): │
│ │
│ /mnt/user/E-yijing.wq-401594/github/mmagic/mmagic/mmagic/models/editors/dreambooth/dreambooth.py │
│ :124 in generate_class_prior_images │
│ │
│ 121 │ │ │
│ 122 │ │ print("=============self.class_prior_prompt: ", self.class_prior_prompt) │
│ 123 │ │ assert self.class_prior_prompt is not None, ( │
│ ❱ 124 │ │ │ '\'class_prior_prompt\' must be set when \'prior_loss_weight\' is ' │
│ 125 │ │ │ 'larger than 0.') │
│ 126 │ │ assert self.num_class_images is not None, ( │
│ 127 │ │ │ '\'num_class_images\' must be set when \'prior_loss_weight\' is ' │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
AssertionError: 'class_prior_prompt' must be set when 'prior_loss_weight' is larger than 0.
I think there is a logic error in the `generate_class_prior_images` function, which is triggered whenever the `prior_loss_weight` parameter is non-zero.
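To make the reported behavior concrete, the guard logic visible in the traceback above can be summarized as follows. This is a simplified sketch, not the actual `dreambooth.py` implementation; the attribute names mirror those in the traceback:

```python
# Simplified sketch of the control flow shown in the traceback above.
# Not the real mmagic implementation; attribute names mirror dreambooth.py.
class DreamBoothSketch:

    def __init__(self, prior_loss_weight=0, class_prior_prompt=None,
                 num_class_images=None):
        self.prior_loss_weight = prior_loss_weight
        self.class_prior_prompt = class_prior_prompt
        self.num_class_images = num_class_images

    def generate_class_prior_images(self, num_batches):
        # These asserts correspond to lines 123-128 in the traceback.
        assert self.class_prior_prompt is not None, (
            "'class_prior_prompt' must be set when 'prior_loss_weight' is "
            'larger than 0.')
        assert self.num_class_images is not None, (
            "'num_class_images' must be set when 'prior_loss_weight' is "
            'larger than 0.')

    def train_step(self, num_batches=1):
        # Corresponds to lines 251-253 in the traceback: any non-zero
        # prior_loss_weight routes into generate_class_prior_images.
        if self.prior_loss_weight != 0:
            self.generate_class_prior_images(num_batches=num_batches)
```

With the default `prior_loss_weight=0` the branch is skipped entirely, which is why the default configuration trains without error while any non-zero weight hits the assertion.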
Reproduces the problem - command or script
https://github.com/open-mmlab/mmagic/blob/main/configs/dreambooth/README.md
Reproduces the problem - error message
Same as above; see the traceback in the code sample section.
Additional information
No response
@XDUWQ, you must pass `class_prior_prompt` (a list of prompts) if you set `prior_loss_weight` larger than 0.
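For reference, a config that enables prior preservation would need to supply both fields checked by the assertions. The key names below mirror the attributes in the traceback, but the exact config layout is an assumption; verify against the official examples under `configs/dreambooth`:

```python
# Hypothetical mmagic-style config excerpt (Python config file).
# Key names mirror the attributes asserted in dreambooth.py; check the
# official configs/dreambooth examples for the real template.
model = dict(
    type='DreamBooth',
    prior_loss_weight=1.0,                     # non-zero: prior preservation on
    class_prior_prompt=['a photo of a dog'],   # required when weight > 0
    num_class_images=200,                      # required when weight > 0
)
```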
Closing this issue since there has been no response. Please feel free to reopen it if needed.