
Can diffusion models be used for image-to-image translation?

lianggaoquan opened this issue 2 years ago · 11 comments

Can diffusion models be used for image-to-image translation?

lianggaoquan · Sep 26 '22

Yes, just concatenate your image to the noised-image input and change the input-channel size.
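A minimal sketch of that idea (the wrapper class and argument names below are hypothetical, not part of this repo):

import torch
import torch.nn as nn

# Hypothetical wrapper: the denoiser sees [noised target, conditioning image]
# stacked along the channel dimension, so the wrapped U-Net must be built with
# an input-channel size of 2 * C while still predicting C output channels.
class ConditionalDenoiser(nn.Module):
    def __init__(self, unet):
        super().__init__()
        self.unet = unet  # any U-Net whose first conv accepts 2 * C channels

    def forward(self, x_noisy, t, x_cond):
        x = torch.cat([x_noisy, x_cond], dim=1)  # concatenate on the channel dim
        return self.unet(x, t)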

robert-graf · Sep 27 '22

@lianggaoquan yeah, what Robert said

I can add it later this week

lucidrains · Sep 27 '22

That depends. I would say for paired i2i you can do what @robert-graf mentioned; however, if you have, for example, segmentation maps as one side of the pair, you might be better off adding a SPADE normalization layer to your UNet instead of attaching the segmentation map to the input.
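A minimal SPADE-style block along those lines could look like this (a rough sketch; the original SPADE paper modulates a (sync) batch norm, and the hidden size here is only illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal SPADE-style layer: normalize the features, then modulate them with
# per-pixel gamma/beta maps predicted from the (resized) segmentation map.
class SPADE(nn.Module):
    def __init__(self, feature_channels, seg_channels, hidden=128):
        super().__init__()
        self.norm = nn.InstanceNorm2d(feature_channels, affine=False)
        self.shared = nn.Sequential(
            nn.Conv2d(seg_channels, hidden, 3, padding=1),
            nn.ReLU(),
        )
        self.gamma = nn.Conv2d(hidden, feature_channels, 3, padding=1)
        self.beta = nn.Conv2d(hidden, feature_channels, 3, padding=1)

    def forward(self, x, segmap):
        segmap = F.interpolate(segmap, size=x.shape[-2:], mode='nearest')
        h = self.shared(segmap)
        return self.norm(x) * (1 + self.gamma(h)) + self.beta(h)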

However, for unpaired i2i I think the current framework most likely will not work, as I can't see how the current training signal would be enough, but maybe I am wrong.

Mut1nyJD · Oct 22 '22

Hi, any update on paired image translation in the repo? Or can anyone show at least a snippet of code for how to modify the repo to do this? Anyway, I really appreciate all the work, I've learned a lot!

FireWallDragonDarkFluid · Feb 12 '23

@robert-graf Where exactly should I perform the concatenation operation? Could you please give more details? I tried to do it at the very beginning of the Unet forward, but it did not work.

> Yes, just concatenate your image to the noised-image input and change the input-channel size.

huseyin-karaca · Aug 18 '23

@huseyin-karaca This Google paper (Palette) introduced this approach: https://iterative-refinement.github.io/palette/.

I did it before the forward call of the U-Net and only updated the input-channel size of the first conv block.

# Conditional p(x_0 | y) ∝ p(x_0) * p(y | x_0) --> the condition y is simply
# concatenated to the model input instead of being handled explicitly.
if x_conditional is not None and self.opt.conditional:
    x = torch.cat([x, x_conditional], dim=1)  # stack the condition on the channel dim
# --------------

Here is the rest for context: my image-to-image code under /img2img2D/diffusion.py. I hope lucidrains is fine with me linking my code here. If you are looking for the referenced paper, the preprint is coming out on Tuesday.
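For reference, the "first conv block" change amounts to building the initial convolution with twice the input channels, roughly as below (the attribute name init_conv and the kernel settings are assumptions about this repo's Unet and may differ between versions):

import torch.nn as nn

# Doubling the input channels of the first conv lets the U-Net accept the
# concatenated [noised image, condition image] tensor.
channels, init_dim = 3, 64
init_conv = nn.Conv2d(2 * channels, init_dim, 7, padding=3)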

robert-graf · Aug 19 '23

@robert-graf Thank you for your kind reply!

huseyin-karaca · Aug 19 '23

Hi, so to do i2i using this repo, is it okay to use the Unet's self_condition=True, or do we have to do the concatenation manually and change the code somewhere else?

heitorrapela · Dec 26 '23

@heitorrapela You would have to manually change the code written in this repo to achieve i2i. The self_condition=True in the Unet from this repo is the implementation of this paper: https://arxiv.org/abs/2208.04202

By the way, diffusion models often achieve better i2i results when starting from a pre-trained model, so maybe you could take a look at Hugging Face's diffusers: https://github.com/huggingface/diffusers
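For example, a pre-trained image-to-image pipeline from diffusers can be used roughly like this (the checkpoint name and parameters are only illustrative):

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

# Load a pre-trained img2img pipeline (checkpoint name is illustrative).
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("input.png").convert("RGB").resize((512, 512))

# strength controls how much of the input image is kept vs. re-generated
result = pipe(
    prompt="a photo in the target domain",
    image=init_image,
    strength=0.75,
    guidance_scale=7.5,
).images[0]
result.save("translated.png")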

FireWallDragonDarkFluid · Dec 26 '23

@FireWallDragonDarkFluid, thanks for the response. I was trying self_condition, but yes, it was not what I wanted, and in the end it was still adding artifacts to the translation process.

I will see if I can implement it myself with this library or with diffusers. Using diffusers, I have only tried simple things, but I still need to train, so I must investigate. Due to my task restrictions, I also cannot use a heavy model such as SD.

heitorrapela · Dec 26 '23

I did a quick implementation, but I am not 100% sure it is correct; I am training some models with it. Here are my modifications if anyone wants to try them as well (a consolidated sketch follows the list):

  • I am using DDIM sampling (sampling_timesteps < timesteps).
  • I updated the UNet channels to be 2 * input_channels, e.g. Unet(dim=64, dim_mults=(1, 2, 4, 8), flash_attn=False, channels=6).
  • Before line 794 (model_out = self.model(x, t, x_self_cond)), I added x = torch.cat([x, x_start], dim=1).
  • As a workaround to make the loss computation work with the 6-channel output, before line 806 I added target = torch.cat([target, x_start], dim=1).
  • Finally, when sampling, I slice out the first three channels, which correspond to the sampled image without the conditioning image.
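Putting those pieces together, here is a hedged sketch of the modified loss; the function name and signature are hypothetical, and the surrounding code in the repo's p_losses will differ:

import torch
import torch.nn.functional as F

def conditional_p_losses(model, x_noisy, t, x_start, target):
    # model   : Unet built with channels = 2 * C (e.g. channels=6 for RGB)
    # x_noisy : noised image at timestep t, shape (B, C, H, W)
    # x_start : the clean image concatenated as the condition, shape (B, C, H, W)
    # target  : the usual diffusion target (e.g. the noise), shape (B, C, H, W)
    x_in = torch.cat([x_noisy, x_start], dim=1)    # 6-channel model input
    model_out = model(x_in, t)                     # 6-channel model output
    target = torch.cat([target, x_start], dim=1)   # workaround: match the output shape
    return F.mse_loss(model_out, target)

# At sampling time, only the first C channels of each sample are kept:
#     image = sample[:, :3]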

heitorrapela · Dec 26 '23