denoising-diffusion-pytorch

Use diffusion for non-square images

Open modantailleur opened this issue 1 year ago • 1 comment

I'd like to use your diffusion library on non-square images (of size 64x128, for example). As far as I understand your code, it currently only handles square images (of size 128x128, for example). On my computer, I've slightly modified denoising_diffusion_pytorch.py so that GaussianDiffusion accepts a non-square image size when I create an instance. Now I can use it like this:

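from denoising_diffusion_pytorch import Unet, GaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

diffusion = GaussianDiffusion(
    model,
    image_size = (64, 128),   # (height, width), accepted only by the modified code
    timesteps = 1000,         # number of steps
    loss_type = 'l1'          # L1 or L2
)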

As I'm new to diffusion models, do you know whether these modifications would cause any theoretical issues in your code? The only things I've changed are the sample and forward methods of GaussianDiffusion, where I replaced image_size with image_size[0] or image_size[1]:


    def sample(self, batch_size = 16, return_all_timesteps = False):
        image_size, channels = self.image_size, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn((batch_size, channels, image_size[0], image_size[1]), return_all_timesteps = return_all_timesteps)

...

    def forward(self, img, *args, **kwargs):
        b, c, h, w, device, img_size = *img.shape, img.device, self.image_size
        assert h == img_size[0] and w == img_size[1], f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = self.normalize(img)
        return self.p_losses(img, t, *args, **kwargs)
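
If existing code that passes a single integer should keep working, the size could be normalized once before it is indexed. The following is only a minimal sketch: `normalize_image_size` and `sample_shape` are hypothetical helpers and not part of the library. It assumes the tuple is ordered (height, width), matching the assert in forward above.

```python
from typing import Tuple, Union


def normalize_image_size(image_size: Union[int, Tuple[int, int]]) -> Tuple[int, int]:
    """Hypothetical helper: return image_size as a (height, width) tuple."""
    if isinstance(image_size, int):
        # A single integer keeps the original square behaviour.
        return (image_size, image_size)
    height, width = image_size  # raises ValueError unless exactly two values
    return (int(height), int(width))


def sample_shape(batch_size: int, channels: int,
                 image_size: Union[int, Tuple[int, int]]) -> Tuple[int, int, int, int]:
    """Shape a modified sample() would hand to the sampling loop."""
    height, width = normalize_image_size(image_size)
    return (batch_size, channels, height, width)


print(sample_shape(16, 3, 128))        # square input
print(sample_shape(16, 3, (64, 128)))  # non-square input
```

With a helper like this, the constructor could store the normalized tuple, and sample and forward could index image_size[0] and image_size[1] regardless of how the size was passed in.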

modantailleur avatar Mar 09 '23 16:03 modantailleur

Yes, that's the correct modification. No other changes are needed, and there are no other issues. I've been training non-square diffusion models for months now.

Mut1nyJD avatar Mar 18 '23 16:03 Mut1nyJD
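
One practical check that may be worth running on a non-square size, sketched here under an assumption that is not stated in the thread: that the U-Net halves the resolution once for every entry in dim_mults after the first, so both height and width need to be divisible by 2 ** (len(dim_mults) - 1). The function `check_divisible` is a hypothetical helper, not a library function.

```python
from typing import Sequence, Tuple


def check_divisible(image_size: Tuple[int, int],
                    dim_mults: Sequence[int] = (1, 2, 4, 8)) -> Tuple[bool, int]:
    """Return (ok, factor) for a (height, width) size.

    Assumption: one 2x downsampling step per dim_mults entry after the first.
    """
    factor = 2 ** (len(dim_mults) - 1)
    height, width = image_size
    ok = height % factor == 0 and width % factor == 0
    return ok, factor


ok, factor = check_divisible((64, 128))
print(f"(64, 128) divisible by {factor}: {ok}")
```

Under that assumption, the 64x128 size from the question passes with dim_mults = (1, 2, 4, 8), while a size such as 60x128 would not.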