How to train an image-conditioned model with my custom dataset?

Open pokameng opened this issue 2 years ago • 34 comments

Hello @justinpinkney, I wanted to email you but failed. Thank you for your great work on the image-conditioned model; I am very interested in it. Can you give me an example of how to train an image-conditioned model with my custom dataset? I have found a config, sd-image.yaml, but it seems to train on the LAION dataset.

pokameng avatar Dec 05 '22 03:12 pokameng

I'm also interested in doing this. Would be keen to discuss.

@pokameng have you made any progress?

BenjaminIrwin avatar Dec 07 '22 13:12 BenjaminIrwin

No, I have no idea how to do it yet.

But we can discuss it.

pokameng avatar Dec 08 '22 04:12 pokameng

Ok nice. What have you tried so far?

BenjaminIrwin avatar Dec 08 '22 09:12 BenjaminIrwin

Hello, maybe we can chat on WeChat, Twitter, or some other platform. I am just using the FrozenCLIPImageEmbedder module, and I have not made any progress.

pokameng avatar Dec 08 '22 12:12 pokameng

@BenjaminIrwin

pokameng avatar Dec 08 '22 12:12 pokameng

Sure, sounds good! Find me on twitter @bentdirwin

BenjaminIrwin avatar Dec 08 '22 16:12 BenjaminIrwin

Hey for local data try taking the image variations config and then replace the data part with something like this:

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 2
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /data/ffhq/images1024x1024
        ext: png
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /data/celeba1000/HQ
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3

The above is assuming square images, but you can add a random or centre crop to the image_transforms if yours aren't square already
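
For non-square images, the crop mentioned above amounts to taking the largest centred square before resizing. The following is a minimal sketch of that arithmetic in plain Python; the function name `center_square_box` is hypothetical and not part of the repository. In the config itself this would presumably be done by adding a crop transform such as `torchvision.transforms.CenterCrop` to `image_transforms`.

```python
from typing import Tuple


def center_square_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centred square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


# A 1024x768 landscape image keeps its middle 768x768 region.
print(center_square_box(1024, 768))  # (128, 0, 896, 768)
```

A random crop would pick `left` and `top` at random within the same bounds instead of centring them.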

Depending on what you're doing, you might want to skip validation, which you can do with:

trainer:
    check_val_every_n_epoch: 100000

Hope that helps. If you have issues, feel free to post your config and any errors. I'm curious what sort of dataset you're planning to use, if you're willing to share details!

justinpinkney avatar Dec 09 '22 14:12 justinpinkney

Thanks! I ran the training, but something is wrong. The condition image is: [image]. The input image is: [image]. But the reconstruction_gs output is: [image], and samples is: [image]. Why do I have this problem?

This is my config:

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: image
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL # ldm.models.autoencoder.VQModelInterface # ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 2
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /home/share/movie_dataset/fanghua/png
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /home/share/movie_dataset/fanghua/png
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3

lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 5000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: False
        log_first_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 8
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    # val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1

My dataset is fanghua. @justinpinkney

pokameng avatar Dec 11 '22 06:12 pokameng

I would also like to know: where is the condition image fed into the model, and where is the input image fed in? What are the requirements for the condition image? @justinpinkney

pokameng avatar Dec 11 '22 06:12 pokameng

I want to use an RGB image as the condition and a grayscale image as the input. How should I set up the config.yaml?
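
One way to approach this is a dataset whose items carry two entries, one for the target and one for the condition, with `first_stage_key` and `cond_stage_key` in the config pointing at them. The sketch below is only an illustration under assumptions: the key names `image` and `image_rgb`, the helpers `to_gray` and `make_example`, and the Rec. 601 luma weights are choices made here, not something confirmed by the repository, and real code would work on tensors rather than nested lists.

```python
from typing import Dict, List, Tuple

Pixel = Tuple[float, float, float]
Image = List[List[Pixel]]


def to_gray(rgb: Image) -> Image:
    """Convert an RGB image to grayscale, kept as three equal channels."""
    gray = []
    for row in rgb:
        out_row = []
        for r, g, b in row:
            y = 0.299 * r + 0.587 * g + 0.114 * b  # Rec. 601 luma (assumption)
            out_row.append((y, y, y))
        gray.append(out_row)
    return gray


def make_example(rgb: Image) -> Dict[str, Image]:
    """Build one training item with separate target and condition entries."""
    return {
        "image": to_gray(rgb),  # hypothetical key for the grayscale target
        "image_rgb": rgb,       # hypothetical key for the RGB condition
    }


example = make_example([[(1.0, 0.0, 0.0), (1.0, 1.0, 1.0)]])
print(sorted(example))  # ['image', 'image_rgb']
```

The grayscale target is kept at three channels here because the autoencoder configs in this thread use `in_channels: 3`.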

pokameng avatar Dec 11 '22 12:12 pokameng

@justinpinkney hello bro

pokameng avatar Dec 17 '22 14:12 pokameng

Sorry, I'm paying attention now. Remind me what your issue is on this? The reconstructions probably look like total noise because the VAE weights aren't being loaded. Try specifying the ckpt_path to load for these, like this: https://github.com/justinpinkney/stable-diffusion/blob/main/configs/stable-diffusion/pokemon.yaml#L46
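
When adding `ckpt_path`, its position in the YAML may matter. Assuming the config loader builds each object from `target` and passes only the entries under `params` to its constructor (an assumption about the loader, worth verifying against the repository's `instantiate_from_config`), a `ckpt_path` placed next to `target` rather than inside `params` might be silently ignored. The helper `find_ckpt_path` below is hypothetical and only reports where the key sits in an already-parsed config block.

```python
from typing import Any, Dict, Optional


def find_ckpt_path(stage_config: Dict[str, Any]) -> Optional[str]:
    """Report where ckpt_path sits in a {target, params} config block."""
    if "ckpt_path" in stage_config.get("params", {}):
        return "params"      # would be forwarded to the constructor
    if "ckpt_path" in stage_config:
        return "top-level"   # sits next to `target`; may never be read
    return None


inside = {
    "target": "ldm.models.autoencoder.AutoencoderKL",
    "params": {"embed_dim": 4,
               "ckpt_path": "models/first_stage_models/kl-f8/model.ckpt"},
}
beside = {
    "target": "ldm.models.autoencoder.AutoencoderKL",
    "ckpt_path": "models/first_stage_models/kl-f8/model.ckpt",
    "params": {"embed_dim": 4},
}
print(find_ckpt_path(inside), find_ckpt_path(beside))  # params top-level
```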

justinpinkney avatar Feb 12 '23 21:02 justinpinkney

@justinpinkney Hello, I have reloaded the VAE weights, but the results still look like this. Condition input: [image]. Input: [image]. Reconstruction: [image]. Samples: [image]. The images above were sampled after 600 iterations.

This is my config:

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: image_rgb
    image_size: 64
    channels: 4
    cond_stage_trainable: False # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL # ldm.models.autoencoder.VQModelInterface # ldm.models.autoencoder.AutoencoderKL
      ckpt_path: models/first_stage_models/kl-f8/model.ckpt
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 3
    num_workers: 6
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /home/share/movie_dataset/fanghua/png
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /home/share/movie_dataset/fanghua/png
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3

lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 100
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 8
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    # val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1

pokameng avatar Feb 13 '23 00:02 pokameng

@justinpinkney Hello, are these the right VAE checkpoints? [image]

pokameng avatar Feb 13 '23 02:02 pokameng

Now the reconstruction looks like this: [image]. The samples look like this: [image]. @justinpinkney

pokameng avatar Feb 13 '23 03:02 pokameng

Yes, those are the zips, and assuming you extracted the checkpoint from inside them, they should contain the weights. The reconstructions literally only touch the VAE encoder/decoder, so if those don't look right, something is wrong with that part of the model.

What's the output of ls -lah models/first_stage_models/kl-f8/model.ckpt? Is it the right size? Might it be corrupted?
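
Besides the file size, a checksum can help rule out a corrupted download. The sketch below is a minimal standard-library check; the helper name `describe_checkpoint` is hypothetical, and since no reference checksum is given in this thread, the digest is only useful when compared against a second download or a value you trust.

```python
import hashlib
import os


def describe_checkpoint(path: str, chunk_size: int = 1 << 20) -> dict:
    """Return the size in bytes and the SHA-256 digest of a file."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # Read in chunks so a multi-gigabyte checkpoint is never held in memory.
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return {"bytes": os.path.getsize(path), "sha256": digest.hexdigest()}


if __name__ == "__main__":
    ckpt = "models/first_stage_models/kl-f8/model.ckpt"
    if os.path.exists(ckpt):
        print(describe_checkpoint(ckpt))
    else:
        print(f"{ckpt} not found")
```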

justinpinkney avatar Feb 15 '23 09:02 justinpinkney

@justinpinkney Hello, the output is 1.1G: [image]. The size is right, but I still have no solution to my problem. Can you give me an example config in which cond_stage_key is image and the conditional input is an image? Thanks!

pokameng avatar Feb 15 '23 14:02 pokameng

In that case I don't know why the VAE reconstructions aren't working. What is the command line you are using to launch training?

justinpinkney avatar Feb 20 '23 13:02 justinpinkney

This is my command line: CUDA_VISIBLE_DEVICES=7 python main.py --base stable-diffusion/configs/fanghua_cond.yaml -t --gpus 0,

and my config is:

model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: image_rgb
    image_size: 64
    channels: 4
    cond_stage_trainable: False # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 1000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL # ldm.models.autoencoder.VQModelInterface # ldm.models.autoencoder.AutoencoderKL
      ckpt_path: stable-diffusion-main/models/first_stage_models/kl-f8/model.ckpt
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 3
    num_workers: 6
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /home/share/movie_dataset/fanghua/png
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.FolderData
      params:
        root_dir: /home/share/movie_dataset/fanghua/png
        ext: jpg
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 256
            interpolation: 3

lightning:
  find_unused_parameters: false
  modelcheckpoint:
    params:
      every_n_train_steps: 500
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 100
        max_images: 4
        increase_log_steps: False
        log_first_step: True
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 8
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    # val_check_interval: 5000000 # really sorry
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1

@justinpinkney

pokameng avatar Feb 21 '23 02:02 pokameng

@justinpinkney hello can i ask you some questions?

pokameng avatar Feb 25 '23 13:02 pokameng

@justinpinkney Do I need to use this checkpoint, sd-clip-vit-l14-img-embed_ema_only.ckpt?

Like this:

model:
  base_learning_rate: 1.0e-04
  ckpt_path: sd-clip-vit-l14-img-embed_ema_only.ckpt
  target: ldm.models.diffusion.ddpm.LatentDiffusion

pokameng avatar Feb 25 '23 14:02 pokameng

There is an extra argument to main.py that you should use to load from an existing checkpoint. Let me check.

justinpinkney avatar Feb 25 '23 16:02 justinpinkney

Use --finetune_from and pass the checkpoint path. This should load the existing model, so the initial set of logs should just look like the output of that model.
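
Put together with the launch command posted earlier in this thread, the invocation might look like the output of the sketch below. The helper `build_train_command` is hypothetical; the flags `--base`, `-t`, `--gpus` and `--finetune_from` are the ones mentioned in this thread, and the config and checkpoint names are taken from the earlier comments as placeholders, not as a confirmed recipe.

```python
import shlex


def build_train_command(config: str, checkpoint: str, gpus: str = "0,") -> str:
    """Assemble a main.py training command that starts from a checkpoint."""
    args = [
        "python", "main.py",
        "--base", config,
        "-t",
        "--gpus", gpus,
        "--finetune_from", checkpoint,
    ]
    # shlex.join quotes any argument that would need it in a POSIX shell.
    return shlex.join(args)


print(build_train_command(
    "stable-diffusion/configs/fanghua_cond.yaml",
    "sd-clip-vit-l14-img-embed_ema_only.ckpt",
))
```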

justinpinkney avatar Feb 25 '23 16:02 justinpinkney

But why does the reconstruction not work? I want to train a new model. @justinpinkney

pokameng avatar Feb 25 '23 16:02 pokameng

I have no solution for it. When I load the checkpoint weights, the reconstruction still does not work. @justinpinkney

pokameng avatar Feb 25 '23 16:02 pokameng

Hi, could I add you on WeChat so we can discuss this together?

dreamlychina avatar Apr 03 '23 09:04 dreamlychina

Sure, I'll add you.

pokameng avatar Apr 03 '23 09:04 pokameng

WeChat: 13739279503

dreamlychina avatar Apr 06 '23 01:04 dreamlychina

Hi, may I add you on WeChat for further discussion? I am also trying to train an image-conditioned model on my data. Thanks!

My WeChat: wtliao

wtliao avatar Apr 20 '23 14:04 wtliao

@justinpinkney Thanks for sharing this nice code base. I have a question about training the image-conditioned model. The Stable Diffusion model is trained by randomly dropping the prompt 10% of the time (that is, the conditioning prompt is set to the empty string "") to preserve the performance of the unconditional model, because it uses classifier-free guidance. When training the image-conditioned model, is there a similar setting? Thanks!
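
For reference, the dropout being asked about can be sketched as follows. This is an illustration of the general classifier-free guidance training trick, not the repository's implementation: whether this code base applies it to image conditioning is not answered in this thread, and the helper `drop_condition` and the use of an all-zero vector as the "null" condition are assumptions.

```python
import random
from typing import List, Optional, Sequence


def drop_condition(
    embedding: Sequence[float],
    null_embedding: Sequence[float],
    drop_prob: float = 0.1,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Replace the condition with a null one with probability drop_prob."""
    source = rng if rng is not None else random
    if source.random() < drop_prob:
        return list(null_embedding)  # train the unconditional branch
    return list(embedding)           # train the conditional branch


rng = random.Random(0)
cond = [0.5, -0.2, 0.9]
null = [0.0, 0.0, 0.0]  # assumed null condition
dropped = sum(drop_condition(cond, null, 0.1, rng) == null for _ in range(10000))
print(dropped / 10000)  # close to 0.1
```

At sampling time, the same null condition would then serve as the unconditional input for guidance.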

wtliao avatar Apr 20 '23 17:04 wtliao