DiffBIR
Question on training Stage2 with Real-ESRGAN degradation
Hi, thanks for sharing the great work!
I tried to follow the training process, but ran into a problem when training the Stage 2 model.
After filling in the `train_cldm.yaml` file, I ran `python train.py --config configs/train_cldm.yaml`, but got the error below:
```
Traceback (most recent call last):
  File "/home/jaeha/Research/DiffBIR/train.py", line 32, in <module>
    main()
  File "/home/jaeha/Research/DiffBIR/train.py", line 28, in main
    trainer.fit(model, datamodule=data_module)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
    self._run(model)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
    self._dispatch()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
    self.accelerator.start_training(self)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
    return self._run_train()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
    self.fit_loop.run()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
    self.trainer.call_hook(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 189, in on_train_batch_end
    callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/callbacks.py", line 55, in on_train_batch_end
    images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 379, in log_images
    samples = self.sample_log(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 394, in sample_log
    samples = sampler.sample(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
TypeError: sample() got an unexpected keyword argument 'unconditional_guidance_scale'
```
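The same class of error can be reproduced outside DiffBIR: Python raises this `TypeError` whenever a caller passes a keyword argument that the callee's signature does not declare. In the minimal sketch below, `SpacedSamplerStub` and its `sample()` parameters (`steps`, `shape`, `cond`, `uncond`, `cfg_scale`) are hypothetical stand-ins, not DiffBIR's actual `SpacedSampler` API.

```python
class SpacedSamplerStub:
    """Hypothetical stand-in for a sampler whose sample() has no
    'unconditional_guidance_scale' parameter (signature is an assumption)."""

    def sample(self, steps, shape, cond, uncond=None, cfg_scale=1.0):
        # Only the parameters listed above are accepted as keywords.
        return {"steps": steps, "shape": shape, "cfg_scale": cfg_scale}


def call_with_old_keyword(sampler):
    """Mimic a call site that still passes an outdated keyword name."""
    try:
        sampler.sample(
            steps=50,
            shape=(1, 4, 64, 64),
            cond={},
            unconditional_guidance_scale=1.0,  # not in the stub's signature
        )
    except TypeError as err:
        return str(err)
    return None


message = call_with_old_keyword(SpacedSamplerStub())
print(message)
```

On Python 3.9 the message matches the last line of the traceback; newer Python versions prefix the method name with the class name, but the "unexpected keyword argument" part is the same.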
I suspect the error originates here: https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L394
where `SpacedSampler.sample` is called with `unconditional_guidance_scale`, a keyword argument that its signature does not accept.
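As a purely local stopgap (not an official fix), a call site could check which keywords a function accepts with `inspect.signature` and set aside the rest before calling it. The sketch below assumes a hypothetical `SpacedSamplerStub` whose `sample()` signature is invented for illustration; note that dropping `unconditional_guidance_scale` silently changes behavior if a non-default guidance scale was intended, so the dropped names are returned for inspection.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def split_supported_kwargs(
    fn: Callable[..., Any], kwargs: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split kwargs into those fn accepts by keyword and those it would reject."""
    params = inspect.signature(fn).parameters
    # A function taking **kwargs accepts every keyword.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs), {}
    keyword_names = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    accepted = {k: v for k, v in kwargs.items() if k in keyword_names}
    rejected = {k: v for k, v in kwargs.items() if k not in keyword_names}
    return accepted, rejected


class SpacedSamplerStub:
    """Hypothetical sampler; this sample() signature is an assumption, not DiffBIR's."""

    def sample(self, steps, shape, cond, uncond=None, cfg_scale=1.0):
        return {"steps": steps, "shape": shape, "cfg_scale": cfg_scale}


sampler = SpacedSamplerStub()
call_kwargs = {
    "steps": 50,
    "shape": (1, 4, 64, 64),
    "cond": {},
    "unconditional_guidance_scale": 1.0,
}
accepted, rejected = split_supported_kwargs(sampler.sample, call_kwargs)
result = sampler.sample(**accepted)
print("dropped:", sorted(rejected))
```

The cleaner fix is presumably to update the call in `sample_log` to match the sampler's current parameter names, which only the maintainers can confirm.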
Could you please let me know how to fix this?