
Question on training Stage2 with Real-ESRGAN degradation

Open · JaehaKim97 opened this issue · 4 comments

Hi, thanks for sharing the great work!

I tried to follow the training process but ran into a problem while training the Stage 2 model.

After filling in the train_cldm.yaml file, I ran python train.py --config configs/train_cldm.yaml and got the error below:

Traceback (most recent call last):
  File "/home/jaeha/Research/DiffBIR/train.py", line 32, in <module>
    main()
  File "/home/jaeha/Research/DiffBIR/train.py", line 28, in main
    trainer.fit(model, datamodule=data_module)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
    self._run(model)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
    self._dispatch()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
    self.accelerator.start_training(self)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
    return self._run_train()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
    self.fit_loop.run()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
    self.trainer.call_hook(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 189, in on_train_batch_end
    callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/callbacks.py", line 55, in on_train_batch_end
    images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 379, in log_images
    samples = self.sample_log(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 394, in sample_log
    samples = sampler.sample(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
TypeError: sample() got an unexpected keyword argument 'unconditional_guidance_scale'
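The last frame of the traceback is decorate_context in torch's grad_mode.py, not the sampler itself. A minimal, torch-free sketch of why this can happen: if sample() is wrapped by a decorator that forwards *args and **kwargs unchanged (which the grad_mode.py frames above suggest), Python rejects the unknown keyword while binding arguments inside the wrapper, so the error is reported from the wrapper's frame. forwarding_decorator and HypotheticalSampler are illustrative stand-ins, not DiffBIR or PyTorch code, and the parameter names are assumptions.

```python
import functools
import traceback


def forwarding_decorator(func):
    """Plain-Python stand-in for a decorator like torch.no_grad(): it wraps
    func and forwards *args/**kwargs to it unchanged (no torch needed here)."""
    @functools.wraps(func)
    def decorate_context(*args, **kwargs):
        # A real grad-mode decorator would enter its context manager here.
        return func(*args, **kwargs)
    return decorate_context


class HypotheticalSampler:
    """Hypothetical sampler whose sample() does NOT declare
    unconditional_guidance_scale (not DiffBIR's real signature)."""

    @forwarding_decorator
    def sample(self, steps, shape):
        return steps, shape


def reproduce():
    """Call sample() the way an out-of-date caller might and report the error."""
    try:
        HypotheticalSampler().sample(
            steps=50, shape=(1, 4, 64, 64), unconditional_guidance_scale=1.0
        )
    except TypeError as err:
        # The innermost traceback frame is the wrapper, not sample() itself,
        # because argument binding fails before sample()'s body ever runs.
        last_frame = traceback.extract_tb(err.__traceback__)[-1]
        return str(err), last_frame.name
    return None, None


message, frame = reproduce()
print(message)  # the exact prefix ("sample()" vs "HypotheticalSampler.sample()") varies by Python version
print(frame)
```

This mirrors the shape of the traceback above: the caller passes a keyword the callee never declared, and the error surfaces one frame early, inside the forwarding wrapper.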

I suspect the error comes from this line: https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L394

There, sample_log passes unconditional_guidance_scale to SpacedSampler.sample(), but that method does not take unconditional_guidance_scale as an argument.
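One way to confirm this kind of mismatch is to compare the keyword arguments a call site passes against the callee's signature. The helper below, rejected_kwargs (a hypothetical name, not part of DiffBIR), is a minimal sketch built on Python's inspect.signature that lists the names a function would refuse. HypotheticalSampler is an illustrative stand-in; its parameters are assumptions, not SpacedSampler's real ones.

```python
import inspect
from typing import Any, Callable, Dict, Set


def rejected_kwargs(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Set[str]:
    """Return the keyword names in kwargs that fn's signature would reject.

    inspect.signature follows __wrapped__ (set by functools.wraps), so a
    decorated function reports the signature of the function it wraps.
    """
    params = list(inspect.signature(fn).parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return set()  # fn accepts **kwargs, so no keyword name is rejected
    # Only these parameter kinds can be passed by keyword.
    accepted = {
        p.name
        for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return set(kwargs) - accepted


class HypotheticalSampler:
    """Hypothetical sampler, NOT DiffBIR's SpacedSampler; names are illustrative."""

    def sample(self, steps, shape, cond=None):
        return steps, shape, cond


call_kwargs = {"steps": 50, "shape": (1, 4, 64, 64), "unconditional_guidance_scale": 1.0}
print(rejected_kwargs(HypotheticalSampler().sample, call_kwargs))
# → {'unconditional_guidance_scale'}
```

Passing a bound method means self is already excluded from the signature. The helper only diagnoses the mismatch: dropping the rejected names before the call would silence the TypeError but would also silently discard the guidance-scale setting.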

Could you please let me know how to fix this?

JaehaKim97, Dec 27 '23 12:12