MultiDiffusion

Region-based generation not working for multiple prompts

Open · Mao718 opened this issue 1 year ago · 4 comments

Hello. I ran into a problem; can anyone help me with this? Here's the code I ran:

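```python
import torch
from region_based import MultiDiffusion  # region_based.py from the MultiDiffusion repo

device = torch.device('cuda')
sd = MultiDiffusion(device)

# Two foreground masks at 512x512: mask 0 covers the top half, mask 1 the bottom half
mask = torch.zeros(2, 1, 512, 512).cuda()
mask[0, :, :256] = 1
mask[1, :, 256:] = 1

# Background mask = whatever the foreground masks leave uncovered (all zeros here)
fg_masks = mask
bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
bg_mask[bg_mask < 0] = 0
masks = torch.cat([bg_mask, fg_masks])  # background first, then the two foreground masks

prompts = ['dog', 'cat']
# neg_prompts = [opt.bg_negative] + opt.fg_negative
print(masks.shape, len(prompts))  # prints: torch.Size([3, 1, 512, 512]) 2
img = sd.generate(masks, prompts, '', width=512)
```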

It raised the following error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 17
     15 #neg_prompts = [opt.bg_negative] + opt.fg_negative
     16 print(masks.shape , len(prompts))
---> 17 img = sd.generate(masks, prompts , '' , width = 512 )

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/Desktop/project/MultiDiffusion/region_based.py:142, in MultiDiffusion.generate(self, masks, prompts, negative_prompts, height, width, num_inference_steps, guidance_scale, bootstrapping)
    139     bg = self.scheduler.add_noise(bg, noise[:, :, h_start:h_end, w_start:w_end], t)
    140     #print(latent.shape , 'latent')
    141     #print(latent_view.shape ,bg.shape,masks_view.shape)
--> 142     latent_view[1:] = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])
    144 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
    145 latent_model_input = torch.cat([latent_view] * 2)

RuntimeError: The expanded size of the tensor (1) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [1, 4, 64, 64].  Tensor sizes: [2, 4, 64, 64]
```
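For reference, here is a minimal sketch of the shape arithmetic on line 142, using NumPy arrays as a CPU stand-in for the torch tensors (both libraries follow the same broadcasting rules). It assumes, without having checked the rest of `generate()`, that `latent_view` holds one latent per prompt while `masks_view` holds one entry per mask, and that `bg` has one entry per non-background prompt. The helper `blend_shapes` is hypothetical and exists only for this illustration.

```python
import numpy as np


def blend_shapes(num_prompts, num_masks, size=64):
    """Hypothetical helper: replay line 142 of region_based.py with NumPy arrays.

    Assumptions: latent_view has one latent per prompt, masks_view one entry per
    mask, and bg one entry per non-background prompt.
    """
    latent_view = np.zeros((num_prompts, 4, size, size))
    masks_view = np.ones((num_masks, 1, size, size))
    bg = np.zeros((num_prompts - 1, 4, size, size))
    # Same expression as line 142; the mask batch size can enlarge the result
    rhs = latent_view[1:] * masks_view[1:] + bg * (1 - masks_view[1:])
    try:
        # Fails when rhs has more entries than the slice it is written into
        latent_view[1:] = rhs
    except ValueError:
        return latent_view[1:].shape, rhs.shape, False
    return latent_view[1:].shape, rhs.shape, True


print(blend_shapes(2, 3))  # as above: 2 prompts, 3 masks (background + 2 foreground)
print(blend_shapes(3, 3))  # one prompt per mask
```

The first call reports a `(1, 4, 64, 64)` target and a `(2, 4, 64, 64)` right-hand side, which matches the sizes in the error message; the second call succeeds. If that assumption holds, `generate()` may expect one prompt per mask with the background prompt first, like the commented-out `[opt.bg_negative] + opt.fg_negative` line.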

Thank you.

Mao718 · Oct 04 '23 07:10