
[Reproduction issue] Semantic image synthesis and layout-to-image cannot be reproduced

Open gene-rative opened this issue 3 years ago • 49 comments

Can you provide inference scripts for semantic image synthesis and layout-to-image synthesis? I tried to use data loaders from the taming-transformers repo but got random noise outputs. The evaluation results are far from those reported in the paper. Thanks!

gene-rative avatar Aug 12 '22 04:08 gene-rative

same question

wangqiang9 avatar Aug 22 '22 03:08 wangqiang9

Thank you very much for publishing your excellent research results. I am also interested in reproducing the layout-to-image model as well. Is there any reproduction code available? Thank you in advance for your consideration.

shunk031 avatar Aug 26 '22 07:08 shunk031

Also waiting for the release of the pretrained layout-to-image model trained from scratch on COCO and the dataset code. Thanks!!

ZGCTroy avatar Sep 12 '22 04:09 ZGCTroy

Also waiting for the semantic synthesis training pipeline

Feanor007 avatar Oct 12 '22 21:10 Feanor007

Hi,

I managed to train the semantic image synthesis model. I first collected the Flickr data following the README of the taming-transformers repo and used sflckr.py as the training dataset.

Then I wrote my YAML config file, following the released config file:

model:
  base_learning_rate: 1.0e-06
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.0015
    linear_end: 0.0205
    log_every_t: 100
    timesteps: 1000
    loss_type: l1
    first_stage_key: image
    cond_stage_key: segmentation
    image_size: 64
    channels: 3
    concat_mode: true
    cond_stage_trainable: true

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ]
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 64
        in_channels: 6
        out_channels: 3
        model_channels: 128
        attention_resolutions:
        - 32
        - 16
        - 8
        num_res_blocks: 2
        channel_mult:
        - 1
        - 4
        - 8
        num_heads: 8

    first_stage_config:
      target: ldm.models.autoencoder.VQModelInterface
      params:
        embed_dim: 3
        n_embed: 8192
        ckpt_path: models/first_stage_models/vq-f4/model.ckpt
        ddconfig:
          double_z: false
          z_channels: 3
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.SpatialRescaler
      params:
        n_stages: 2
        in_channels: 182
        out_channels: 3

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 12
    num_workers: 5
    wrap: False
    train:
      target: ldm.data.flickr.FlickrSegTrain  #  PUT YOUR DATASET 
      params:
        size: 256
    validation:
      target: ldm.data.flickr.FlickrSegEval  #  PUT YOUR DATASET 
      params:
        size: 256

lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 5000
        max_images: 8
        increase_log_steps: False


  trainer:
    benchmark: True

And last, I ran python main.py --base <config_above>.yaml -t --gpus 0, to train the model (the trailing comma after 0 is part of the PyTorch Lightning --gpus argument).

It did work. Here is a result from my training process:

[conditioning: original_conditioning_gs-045000_e-000082_b-000044 | samples: samples_gs-045000_e-000082_b-000044]

By the way, I noticed that the released config YAML doesn't load a checkpoint in the first stage config:

first_stage_config:
  target: ldm.models.autoencoder.VQModelInterface
  params:
    embed_dim: 3
    n_embed: 8192
    ckpt_path: models/first_stage_models/vq-f4/model.ckpt  # this line is missing 
    ddconfig:
      double_z: false

I wonder whether this is why inference fails.

otamic avatar Oct 15 '22 02:10 otamic

@otamic I saw your fantastic results. I am struggling with how to run inference (testing) with the pretrained model to generate landscape images from segmentation maps. Could you share your inference (test) code if possible?

YorkNishi999 avatar Oct 23 '22 21:10 YorkNishi999

@YorkNishi999

Here is my inference code, which mostly follows the image-logging code in ddpm.py:

import torch
import numpy as np

from scripts.sample_diffusion import load_model
from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from einops import rearrange

from ldm.data.flickr import FlickrSegEval


def ldm_cond_sample(config_path, ckpt_path, dataset, batch_size):
    # load_model comes from scripts/sample_diffusion.py and puts the model on the GPU
    config = OmegaConf.load(config_path)
    model, _ = load_model(config, ckpt_path, None, None)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # grab one batch of one-hot segmentation maps, shape (b, h, w, c)
    x = next(iter(dataloader))
    seg = x['segmentation']

    with torch.no_grad():
        seg = rearrange(seg, 'b h w c -> b c h w')
        condition = model.to_rgb(seg)  # only for visualizing the condition

        seg = seg.to('cuda').float()
        seg = model.get_learned_conditioning(seg)

        # DDIM sampling in latent space, then decoding with the first stage model
        samples, _ = model.sample_log(cond=seg, batch_size=batch_size, ddim=True,
                                      ddim_steps=200, eta=1.)
        samples = model.decode_first_stage(samples)

    save_image(condition, 'cond.png')
    save_image(samples, 'sample.png')


if __name__ == '__main__':

    config_path = 'models\ldm\semantic_synthesis256\config.yaml'
    ckpt_path = 'models\ldm\semantic_synthesis256\model.ckpt'

    dataset = FlickrSegEval(size=256)

    ldm_cond_sample(config_path, ckpt_path, dataset, 4)

Note that one line is missing from the config file, as I described above.

I simply picked some segmentation maps from the dataset to generate images; you may want to adapt this to suit your needs.
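
If you want to condition on your own segmentation map instead of one pulled from the dataset, here is a rough sketch of how I would prepare the input. The PNG path and label count are placeholders, and it assumes the map stores integer class ids per pixel (the model here was trained with 182 classes):

import numpy as np
import torch
from PIL import Image
from einops import rearrange

def load_segmentation(png_path, n_labels=182, size=256):
    """Turn an integer label map into the one-hot tensor the conditioning stage expects."""
    seg = Image.open(png_path).resize((size, size), Image.NEAREST)
    seg = np.array(seg).astype(np.int64)               # (H, W) integer class ids
    onehot = np.eye(n_labels, dtype=np.float32)[seg]   # (H, W, n_labels) one-hot
    return torch.from_numpy(onehot)[None]              # (1, H, W, n_labels), like the dataset output

# seg = load_segmentation('my_map.png')                    # hypothetical file
# seg = rearrange(seg, 'b h w c -> b c h w').to('cuda')
# cond = model.get_learned_conditioning(seg)               # then sample as in the script above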

otamic avatar Oct 24 '22 01:10 otamic

@otamic I am very grateful for you to share your code!

I used your code and generated images, but they are low quality. Just to confirm: you first train the model, and then run inference (generating images from semantic maps) with ckpt_path = 'models\ldm\semantic_synthesis256\model.ckpt'. Am I correct?

My generated image is here:

[image]

YorkNishi999 avatar Oct 24 '22 07:10 YorkNishi999

@YorkNishi999

In fact, models\ldm\semantic_synthesis256\model.ckpt refers to the pretrained model downloaded from Pretrained LDMs when I wrote this code.

To test your own trained model, just change the path to something like logs/xxxx/checkpoints/last.ckpt after a training process. (So you are right.)

Here is a result from the downloaded model:

[condition: cond | sample: sample]

And from my trained model:

[condition: my_cond | sample: my_sample]

It works fine here. So I wonder whether you just haven't trained your model long enough.

otamic avatar Oct 24 '22 07:10 otamic

Works perfectly on my side, thanks @otamic !

JeanJulesBigeard avatar Oct 24 '22 08:10 JeanJulesBigeard

@otamic Thank you for sharing your experiments!! I will retry it with some more training.

YorkNishi999 avatar Oct 25 '22 05:10 YorkNishi999

@otamic I got good results after tracking down my bugs (it was my fault).

Thank you again for your kindness!

[image]

YorkNishi999 avatar Oct 25 '22 06:10 YorkNishi999

@otamic Wow, that's nice. Can you share your dataloader code? I want to be sure about something before I write my own :D

SerdarHelli avatar Oct 25 '22 21:10 SerdarHelli

@SerdarHelli

I think you mean the dataset class in the config file:

data:
  ...
  params:
    ...
    train:
      target: ldm.data.flickr.FlickrSegTrain  #  PUT YOUR DATASET 
      ...
    validation:
      target: ldm.data.flickr.FlickrSegEval  #  PUT YOUR DATASET 
      ...

If so, I used the code from sflckr.py as described above. There is an Examples class in that script:

class Examples(SegmentationBase):
    def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv="data/sflckr_examples.txt",
                         data_root="data/sflckr_images",
                         segmentation_root="data/sflckr_segmentations",
                         size=size, random_crop=random_crop, interpolation=interpolation)

And I added my own dataset classes, referring to my data (collected according to this), like so:

class FlickrSegTrain(SegmentationBase):
    def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv='data/flickr/flickr_train.txt',
                         data_root='data/flickr/flickr_images',
                         segmentation_root='data/flickr/flickr_segmentations',
                         size=size, random_crop=random_crop, interpolation=interpolation)


class FlickrSegEval(SegmentationBase):
    def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv='data/flickr/flickr_eval.txt',
                         data_root='data/flickr/flickr_images',
                         segmentation_root='data/flickr/flickr_segmentations',
                         size=size, random_crop=random_crop, interpolation=interpolation)

That's all I did. (They are only very minor changes, so I didn't make a separate post about them.)

At this point, I believe I have written down everything needed to reproduce the semantic synthesis result.
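
In case it helps: as far as I can tell, SegmentationBase simply reads data_csv as a plain text file of image filenames relative to data_root, and expects a segmentation map with the matching name (as .png) under segmentation_root. A small helper I would use to generate the two split files (the paths and the 95/5 split are my own choices):

import os
import random

def write_split_files(image_root='data/flickr/flickr_images',
                      train_txt='data/flickr/flickr_train.txt',
                      eval_txt='data/flickr/flickr_eval.txt',
                      eval_fraction=0.05, seed=0):
    # list image filenames relative to image_root; segmentations must share these names (as .png)
    names = sorted(f for f in os.listdir(image_root) if f.lower().endswith('.jpg'))
    random.Random(seed).shuffle(names)
    n_eval = max(1, int(len(names) * eval_fraction))
    with open(eval_txt, 'w') as f:
        f.write('\n'.join(names[:n_eval]) + '\n')
    with open(train_txt, 'w') as f:
        f.write('\n'.join(names[n_eval:]) + '\n')

if __name__ == '__main__':
    write_split_files()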

otamic avatar Oct 26 '22 02:10 otamic

Yes, thanks @otamic. https://github.com/CompVis/taming-transformers/blob/master/taming/data/sflckr.py is the one I was actually looking for. I knew they had written it, but I hadn't checked it out :D

SerdarHelli avatar Oct 26 '22 12:10 SerdarHelli

@otamic I have trained semantic synthesis 256 on Cityscapes with the same config you shared, but I am getting this image as a result (see the attached sample). Do you have any idea why this happens?

mmash98 avatar Oct 26 '22 14:10 mmash98

@mmash98 I think you should check your config. For example, is your conditioning stage input channel count still 182? How many labels does your dataset have?
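
If I read the sflckr loader correctly, the number of one-hot channels is set by the dataset's n_labels argument, so it has to agree with cond_stage_config.params.in_channels. A sketch for a 35-label setup (the class name and paths are made up):

from ldm.data.flickr import SegmentationBase  # or wherever you copied the sflckr base class

class CityscapesSegTrain(SegmentationBase):  # hypothetical class name
    def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
        super().__init__(data_csv='data/cityscapes/train.txt',
                         data_root='data/cityscapes/images',
                         segmentation_root='data/cityscapes/segmentations',
                         size=size, random_crop=random_crop, interpolation=interpolation,
                         n_labels=35)  # must equal in_channels of the SpatialRescaler in your config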

SerdarHelli avatar Oct 26 '22 14:10 SerdarHelli

@mmash98: @SerdarHelli I have changed it as well; in my case it is 35.

I see. Did you check your batches? I don't know, maybe you just didn't train long enough.

SerdarHelli avatar Oct 26 '22 14:10 SerdarHelli

@mmash98

Could you try a smaller batch size, such as 4? If that doesn't help, I have no other ideas.

otamic avatar Oct 26 '22 15:10 otamic

I think he just didn't train long enough. At 5k steps, I am getting the same results.

SerdarHelli avatar Oct 30 '22 09:10 SerdarHelli

Guys, in addition, should we train our own VQGAN? I think we should train the VQGAN on our own data if our domain is very different.

Edit: I am getting worse results with LDM + VQ-f4 than with a GAN for semantic image synthesis. Probably I should train more, or my data is too limited for an LDM; maybe LDMs are not good on limited data.

Also, you can train on Colab. I can share the code.

SerdarHelli avatar Oct 30 '22 09:10 SerdarHelli

@otamic Hey, may I ask a question? I followed your YAML and inference script to train on DeepFashion, whose semantic maps have 24 categories, and I changed 182 to 24. But my results are strange, as shown below. Is there anything else I need to pay attention to, or did I do something wrong? Looking forward to your reply, thanks so much! [screenshot]

Kai-0515 avatar Nov 05 '22 13:11 Kai-0515

@Kai-0515

I think you didn't successfully load the pretrained first stage model. Check that the missing line I mentioned has been added, and make sure the ckpt file actually exists. I had similar results myself, which is how I found the missing line.
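
A quick sanity check I would run (my own sketch, reusing the helpers from the inference script above): round-trip a real image through the first stage. If the VQ weights loaded correctly, the reconstruction should look almost identical to the input; if not, it will come out garbled.

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from einops import rearrange

from scripts.sample_diffusion import load_model
from ldm.data.flickr import FlickrSegEval  # your own dataset class

config = OmegaConf.load('models/ldm/semantic_synthesis256/config.yaml')
model, _ = load_model(config, 'models/ldm/semantic_synthesis256/model.ckpt', None, None)

batch = next(iter(DataLoader(FlickrSegEval(size=256), batch_size=4)))
x = rearrange(batch['image'], 'b h w c -> b c h w').to('cuda').float()  # images are in [-1, 1]

with torch.no_grad():
    z = model.get_first_stage_encoding(model.encode_first_stage(x))
    rec = model.decode_first_stage(z)

save_image((x + 1) / 2, 'input.png')
save_image((rec + 1) / 2, 'rec.png')  # a garbled output here points to a missing/incorrect first stage ckpt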

otamic avatar Nov 05 '22 14:11 otamic

@otamic You're right! Thanks very much for your quick reply!

Kai-0515 avatar Nov 06 '22 03:11 Kai-0515

[image]

Unlike the GAN methods, in LDM the condition is converted to an RGB-like 3-channel image, so your categories must be correct; otherwise you will feed the model the wrong conditioning. Also, you must be sure about the autoencoder (VQ or KL).
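
To make that concrete, here is a rough illustration (my own sketch based on the config posted above, not code from the repo) of what happens to the condition: the one-hot segmentation is downsampled by the SpatialRescaler and projected from n_labels channels to 3, and it is this 3-channel map that conditions the diffusion model, so the one-hot channel count has to match in_channels.

import torch
from ldm.modules.encoders.modules import SpatialRescaler

# same settings as cond_stage_config above; 182 would become 35 for Cityscapes, 24 for DeepFashion, etc.
rescaler = SpatialRescaler(n_stages=2, in_channels=182, out_channels=3)

seg = torch.zeros(1, 182, 256, 256)  # one-hot segmentation: channels = number of label classes
seg[:, 0] = 1.0                      # dummy map: every pixel is class 0
cond = rescaler(seg)                 # two 2x downsampling stages + 1x1 projection to 3 channels
print(cond.shape)                    # expected: torch.Size([1, 3, 64, 64])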

SerdarHelli avatar Nov 07 '22 11:11 SerdarHelli

Has anybody trained a model for the layout2img task yet? I'm not quite sure what my bounding-box input is supposed to look like, or what a proper configuration would be. Thank you so much for any input. I know the layout2img-openimages256 config exists, but I'm still not sure about the expected input format.

mauerflitzer avatar Nov 14 '22 08:11 mauerflitzer

@otamic Do I understand correctly that you train everything from scratch, i.e. the whole model except the VQ-f4 first stage? Is it also possible to skip training the UNet and VAE and only train the conditioning part?

mauerflitzer avatar Nov 14 '22 09:11 mauerflitzer

@mauerflitzer

You are correct about my training. In my opinion, training only the conditioning part is not possible with LDM. First, how would you supervise that training? Second, the UNet structures of the conditional and unconditional models are different: in this case, the number of channels at the UNet input is doubled when the condition is concatenated. What you describe sounds more like classifier-guided diffusion, which is a different way of conditioning.
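
For reference, the channel arithmetic implied by the config above (my own illustration): the VQ-f4 latent has 3 channels, the rescaled condition has 3, and concat_mode stacks them, which is why unet_config.params.in_channels is 6 while out_channels is 3.

import torch

z = torch.randn(1, 3, 64, 64)       # VQ-f4 latent (first stage z_channels: 3)
c = torch.randn(1, 3, 64, 64)       # rescaled segmentation (SpatialRescaler out_channels: 3)
unet_in = torch.cat([z, c], dim=1)  # concat conditioning -> 6 channels into the UNet
print(unet_in.shape)                # torch.Size([1, 6, 64, 64])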

otamic avatar Nov 14 '22 14:11 otamic

@otamic I thought about freezing the UNet and VAE weights, taking a released checkpoint of 1.4 or maybe 1.5, and then swapping out the conditioning part for a new one and training only that.

mauerflitzer avatar Nov 14 '22 14:11 mauerflitzer

@mauerflitzer

Sorry, I don't understand what you mean by the checkpoint of 1.4 or 1.5. If the conditioning parts (τ_θ) work in the same way, I think you can just try it, although my intuition is that it might not work.

otamic avatar Nov 15 '22 14:11 otamic