ControlNet
Inference (Test)
How can I run inference on unseen images during the training phase?
@lllyasviel if you have any snippets of code on how you run inference on unseen images I can take a stab at adding validation during model training. Thanks!
same question
Hi @bekhzod-olimov, @pfdamasceno
Have any of you managed to do inference after training the model? I can't find any script to use for post-training inference.
Thanks in advance
Hi all :) has anyone worked on an inference script so far?
Hi all, I've hacked together an inference function run_sampler() that you can use for sampling. It's based on the scribble2image gradio demo from the original authors (@lllyasviel et al.). I wanted to put this forward as a community resource, since a lot of people seem to be asking after it.
https://github.com/gabegrand/ControlNet/blob/e7e89975bf4dd0f136681a777986d96c883112c9/project/inference.py
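In case it's useful at a glance, here is a minimal sketch of demo-style sampling modeled on the repo's gradio_scribble2image.py; the checkpoint path, the empty negative prompt, the function name run_inference, and the hyperparameter defaults are placeholders, not the exact contents of the linked script.

import einops
import numpy as np
import torch
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

# Build the model from the config and load a trained checkpoint (paths are placeholders).
model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict('./your_trained_checkpoint.ckpt', location='cuda'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)

@torch.no_grad()
def run_inference(control_image, prompt, num_samples=1, ddim_steps=50, scale=9.0, eta=0.0, strength=1.0):
    # control_image: HxWx3 uint8 array holding the conditioning map (e.g. a scribble).
    H, W, _ = control_image.shape
    control = torch.from_numpy(control_image.copy()).float().cuda() / 255.0
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()

    cond = {"c_concat": [control],
            "c_crossattn": [model.get_learned_conditioning([prompt] * num_samples)]}
    un_cond = {"c_concat": [control],
               "c_crossattn": [model.get_learned_conditioning([""] * num_samples)]}
    shape = (4, H // 8, W // 8)

    model.control_scales = [strength] * 13
    samples, _ = ddim_sampler.sample(ddim_steps, num_samples, shape, cond, verbose=False, eta=eta,
                                     unconditional_guidance_scale=scale,
                                     unconditional_conditioning=un_cond)
    x_samples = model.decode_first_stage(samples)
    x_samples = einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5
    return x_samples.cpu().numpy().clip(0, 255).astype(np.uint8)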
Hi @gabegrand!
Is this script only for inference after training, or can it also be used during the training phase?
Looking forward to hearing from you. Thanks!
I wrote an inference script with small code modifications. I hope my experience helps...
First, change two parts of the training code tutorial_train.py.
resume_path = "YOUR_OWN_CKPT"  # point this at your trained checkpoint
# trainer.fit(model, dataloader)            # disable training
trainer.test(model, dataloaders=dataloader)  # run the test loop instead
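For anyone who wants to see where those two lines live, here is a trimmed sketch of tutorial_train.py with the changes applied; everything else is assumed to follow the stock script in the repo (shuffle=False is just a choice to keep test batches deterministic).

from share import *
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.model import create_model, load_state_dict

resume_path = "YOUR_OWN_CKPT"  # change 1: your trained checkpoint instead of the initial weights
batch_size = 4

model = create_model('./models/cldm_v15.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.only_mid_control = False

dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=False)
trainer = pl.Trainer(gpus=1, precision=32)

# trainer.fit(model, dataloader)            # change 2: skip training...
trainer.test(model, dataloaders=dataloader)  # ...and run the test loop instead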
PyTorch Lightning runs trainer.test using the test_step function declared on the model.
The class structure of the model used in ControlNet is as follows: the ControlLDM class inherits from LatentDiffusion, and LatentDiffusion inherits from DDPM.
Therefore, let's add a test_step function to DDPM, the topmost parent class, which is defined in ldm/models/diffusion/ddpm.py.
def shared_step_test(self, batch):  # not actually used; overridden in LatentDiffusion
    x = self.get_input(batch, self.first_stage_key)
    loss, loss_dict = self(x)
    return loss, loss_dict

@torch.no_grad()
def test_step(self, batch, batch_idx):
    self.shared_step_test(batch)
The shared_step_test declared here is not actually used, because it is overridden by the child class's version.
LatentDiffusion is also defined in ldm/models/diffusion/ddpm.py. Override shared_step_test there as shown below.
The function that saves the images was written by referring to cldm/logger.py.
# Needs os, torchvision, numpy as np, and PIL's Image available in ddpm.py (add the imports if they are missing).
def log_local(self, save_dir, split, images, global_step="", current_epoch="", batch_idx=""):
    root = os.path.join(save_dir, "image_log", split)
    for k in images:
        grid = torchvision.utils.make_grid(images[k], nrow=4)
        grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        grid = grid.numpy()
        grid = (grid * 255).astype(np.uint8)
        filename = "{}.png".format(k)
        path = os.path.join(root, filename)
        os.makedirs(os.path.split(path)[0], exist_ok=True)
        if grid.shape[2] == 6:  # this branch is for when the ControlNet receives two conditions as input
            grid1 = grid[:, :, :3]
            grid2 = grid[:, :, 3:]
            new_grid = np.concatenate((grid1, grid2))
            Image.fromarray(new_grid).save(path)
        else:
            Image.fromarray(grid).save(path)
def shared_step_test(self, batch, **kwargs):
    images = self.log_images(batch, split="test_", ddim_steps=50)
    for k in images:
        N = min(images[k].shape[0], 3)
        images[k] = images[k][:N]  # keep at most N samples per key
        if isinstance(images[k], torch.Tensor):
            images[k] = images[k].detach().cpu()
            images[k] = torch.clamp(images[k], -1., 1.)
    self.log_local("./", "test_", images)
Parameters used for inference, such as the CFG scale or the number of DDIM steps, can be changed through log_images, which is defined in ControlLDM. Since log_images is overridden by ControlLDM, it is pointless to modify it in LatentDiffusion. ControlLDM is declared in cldm/cldm.py.
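As a concrete example, assuming the log_images signature currently in cldm/cldm.py (double-check the parameter names against your copy), you can either edit its default arguments or pass the values directly from the shared_step_test call above:

# Inside LatentDiffusion.shared_step_test: forward the inference settings to ControlLDM.log_images.
# unconditional_guidance_scale is the CFG scale; ddim_steps is the number of DDIM sampling steps.
images = self.log_images(batch, split="test_", ddim_steps=50,
                         unconditional_guidance_scale=9.0)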
I tried it according to the method you mentioned, and the following error occurred:
return self.model.test_step(*args, **kwargs)
  File "/home/yzli/ControlNet/ldm/models/diffusion/ddpm.py", line 529, in test_step
    self.shared_step_test(batch)
  File "/home/yzli/ControlNet/ldm/models/diffusion/ddpm.py", line 524, in shared_step_test
    loss, loss_dict = self(x)
  File "/home/yzli/miniconda3/envs/control/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'c'
@Psycho-9
I don't know exactly what your situation is, but based on the error message, the failure happens in DDPM's forward(), and it looks like it is caused by executing the shared_step_test() defined in the DDPM class.
For the code I wrote to work properly, the shared_step_test() that gets called must be the one overridden in the LatentDiffusion class, not the one defined in DDPM.
- Is the shared_step_test() function defined in the LatentDiffusion class? As written above, the function must be defined in both DDPM and LatentDiffusion (with different contents).
- Is your trained network DDPM-based rather than ControlLDM-based? If you are testing a network trained with plain DDPM, you may need to modify the code a bit.
- Did you modify the model itself for your own purposes? In that case, I think it will be difficult for me to help you.
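If it helps, a quick sanity check (plain Python introspection, nothing specific to the repo) can confirm that shared_step_test resolves to the LatentDiffusion override rather than the DDPM stub:

# Run right after the model is built in tutorial_train.py.
print(type(model).__mro__)  # expect ControlLDM -> ... -> LatentDiffusion -> DDPM -> ...
print(type(model).shared_step_test.__qualname__)  # expect 'LatentDiffusion.shared_step_test'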
@gabegrand Hello. The run_sampler function in your GitHub repository (https://github.com/gabegrand/ControlNet/blob/main/project/inference.py) is very helpful to me. However, my experiment on the fill50k dataset did not go well: the result was strongly driven by the prompt, while the source image had almost no effect, which I guess may be caused by insufficient training time. May I ask how many epochs you trained for? I am looking forward to your reply, and thanks again!