
Conditional *image* generation (img2img)

Open · CallShaul opened this issue 1 year ago · 2 comments

Hi,

In order to add support for conditional image generation, in addition to embedding the initial image into unet_cond (extra_args['unet_cond'] = img_cond), what should I put in extra_args['cross_cond'] and extra_args['cross_cond_padding']?

(before the loss calculation in the line: losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args))
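For context, this is roughly the part of the training loop I'm working from (just a sketch, not verbatim from train.py; img_cond is my own conditioning image for the batch):

    # reals, noise, sigma, aug_cond, extra_args and model come from the
    # surrounding training loop; img_cond is the conditioning image.
    extra_args['unet_cond'] = img_cond

    # These are the two entries I don't know how to fill in:
    # extra_args['cross_cond'] = ?
    # extra_args['cross_cond_padding'] = ?

    losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)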

@crowsonkb @nekoshadow1 @brycedrennan

Thanks!

CallShaul avatar Jun 19 '23 11:06 CallShaul

I'd be interested in seeing an answer to this as well. For example, in the simple case of MNIST, how might we implement (or activate) class-conditional generation?

I see class_cond in the code, and a cond_dropout_rate in the config files, so maybe it's already training that way... But in the output from demo(), the classes seem to just be random. Perhaps we just need to change line 369 in train.py from this...

            class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device)

To something more "intentional", such as...

            class_cond = torch.remainder(torch.arange(accelerator.num_processes * n_per_proc), num_classes).reshape([accelerator.num_processes, n_per_proc]).int().to(device)

....?
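As a quick sanity check (standalone, with made-up sizes), that expression lays the demo labels out by cycling through the classes:

    import torch

    # Made-up sizes: 1 process, 16 demo images per process, 10 MNIST classes.
    num_processes, n_per_proc, num_classes = 1, 16, 10

    class_cond = torch.remainder(
        torch.arange(num_processes * n_per_proc), num_classes
    ).reshape([num_processes, n_per_proc])

    print(class_cond)
    # tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5]])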

Update: yep! That worked! :-)

(attached image: demo_grid_13499, a class-conditional sample grid)

drscotthawley avatar Feb 22 '24 18:02 drscotthawley

Solution:

I've made it work. Here are the main steps (some more workarounds are needed to make it run, in inference as well, but this is the main idea):

  1. Get the conditioning image in each batch training iteration and embed it into unet_cond:

         unet_cond = get_condition_channels(model_config, img_cond)
         extra_args['unet_cond'] = unet_cond.to(device)

  2. Add the image condition to the "losses" line calculation; it is passed in through **extra_args:

         losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)

  • Some fixes are needed in the forward function of the model file image_v1.py.
  • Perform similar conditioning in the inference stage (a rough sketch follows below).
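For the inference stage, the conditioning looks roughly like this (a sketch only, not verbatim code; the sampler choice, model_ema, sigmas and the tensor shapes are whatever your setup uses, and get_condition_channels / img_cond are the same as in the training step above):

    import torch
    import k_diffusion as K

    # Build the same UNet conditioning as during training.
    unet_cond = get_condition_channels(model_config, img_cond)
    extra_args = {'unet_cond': unet_cond.to(device)}

    # Start from noise at the largest sigma; the sampler forwards the
    # conditioning to the model on every step via extra_args.
    x = torch.randn([n_samples, channels, size, size], device=device) * sigmas[0]
    samples = K.sampling.sample_dpmpp_2m(model_ema, x, sigmas, extra_args=extra_args)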

CallShaul avatar Apr 18 '24 12:04 CallShaul