k-diffusion
Conditional *image* generation (img2img)
Hi,
In order to add support for conditional image generation, in addition to embedding the initial image into unet_cond (extra_args['unet_cond'] = img_cond), what should I put in extra_args['cross_cond'] and extra_args['cross_cond_padding'] before the loss calculation in the line losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)?
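For context, roughly where this sits in the training loop; this is only a sketch, and the dataloader, img_cond, sample_density and aug_cond names below are placeholders for whatever the existing loop already provides:

```python
import torch

# Sketch of one training iteration with image conditioning.
# 'dataloader', 'img_cond', 'sample_density' and 'aug_cond' are assumed
# stand-ins; only extra_args['unet_cond'] and the model.loss call are
# taken from the loss line quoted above.
for reals, img_cond in dataloader:
    noise = torch.randn_like(reals)
    sigma = sample_density([reals.shape[0]], device=reals.device)  # per-sample noise levels

    extra_args = {}
    extra_args['unet_cond'] = img_cond.to(reals.device)  # image condition fed to the UNet

    # cross_cond / cross_cond_padding are the cross-attention (e.g. text) inputs;
    # the question is what to put there when only an image condition is used.
    losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)
    loss = losses.mean()
    loss.backward()
```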
@crowsonkb @nekoshadow1 @brycedrennan
Thanks!
I'd be interested in seeing an answer to this as well. e.g. for the simple case of MNIST, how might we implement (or activate) class-conditional generation?
I see class_cond in the code, and a cond_dropout_rate in the config files, so maybe it's already training that way... But in the output from demo(), it seems to just be random. Perhaps we just need to change line 369 in train.py from this...
class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device)
To something more "intentional", such as...
class_cond = torch.remainder(torch.arange(0, accelerator.num_processes * n_per_proc), num_classes).reshape([accelerator.num_processes, n_per_proc]).int().to(device)
....?
Update: yep! That worked! :-)
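To see what this gives, here's a tiny standalone check of the label pattern that replacement line produces (the numbers below are just example values, not what the accelerator actually reports):

```python
import torch

# Stand-in values for illustration only; in train.py these come from the
# accelerator and the config.
num_processes, n_per_proc, num_classes = 1, 16, 10

class_cond = torch.remainder(
    torch.arange(0, num_processes * n_per_proc), num_classes
).reshape([num_processes, n_per_proc]).int()

print(class_cond)
# tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5]], dtype=torch.int32)
# Each demo sample now gets a fixed, cycling class label instead of a random one.
```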
Solution:
I've made it work; here are the main steps (a few more workarounds are needed to make it run, in inference as well, but this is the main idea):
- get the conditioning image in each training batch iteration:
unet_cond = get_condition_channels(model_config, img_cond)
extra_args['unet_cond'] = unet_cond.to(device)
- modify the losses calculation line so the image condition is passed in via extra_args:
losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)
- some fixes are needed in the forward function of the model file image_v1.py
- perform similar conditioning at the inference stage (see the sketch below)
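For the inference stage, a rough sketch of what I mean; the sampler choice, the Karras sigma schedule and the sigma_min/sigma_max config keys are just what I'd reach for, so adapt them to whatever the demo/inference code already uses. get_condition_channels is the same helper used in the training step above:

```python
import torch
from k_diffusion import sampling

# Sketch of inference-time image conditioning: build extra_args the same way
# as in training and hand it to the sampler. Sampler, schedule and config
# keys are assumptions, not the only way to wire this up.
@torch.no_grad()
def sample_conditioned(model_ema, img_cond, model_config, size, device, steps=50):
    x = torch.randn(size, device=device) * model_config['sigma_max']
    sigmas = sampling.get_sigmas_karras(steps, model_config['sigma_min'],
                                        model_config['sigma_max'], device=device)
    extra_args = {'unet_cond': get_condition_channels(model_config, img_cond).to(device)}
    return sampling.sample_dpmpp_2m(model_ema, x, sigmas, extra_args=extra_args)
```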