Dreambooth
Accelerate's gradient-accumulation context manager (accelerator.accumulate) does not support multiple models yet:
https://github.com/huggingface/accelerate/issues/668
Do this manually?
for step, batch in enumerate(train_dataloader):
    # TODO: how to handle context setting when unet is not training?
    # https://stackoverflow.com/a/14029481
    # train_all = train_unet and train_text_encoder
    # with (accelerator.accumulate(unet), accelerator.accumulate(text_encoder)) if train_all else (accelerator.accumulate(unet) if train_unet else accelerator.accumulate(text_encoder))
    with accelerator.accumulate(unet):
        ...  # rest of the training step (not shown)
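One possible "do it manually" route is sketched below. It is a minimal sketch under stated assumptions, not the Dreambooth script's actual code. It assumes the Accelerator is created with its default gradient_accumulation_steps, so the loss is scaled by hand here. Instead of stacking one accelerate.accumulate() per model, it uses Accelerator.no_sync(model) to skip gradient synchronization on accumulation-only steps. It also uses contextlib.ExitStack, a standard way to choose the set of context managers at runtime, so only the models that are actually being trained get a no_sync context. The names should_sync, maybe_no_sync, train_epoch and the compute_loss callable are hypothetical.

```python
import contextlib


def should_sync(step, grad_accum_steps, num_batches):
    """True when gradients should be synchronized and the optimizer stepped:
    every `grad_accum_steps` batches, and always on the last batch."""
    return (step + 1) % grad_accum_steps == 0 or step + 1 == num_batches


@contextlib.contextmanager
def maybe_no_sync(accelerator, models, sync):
    """On accumulation-only steps, enter accelerator.no_sync(m) for each model.

    ExitStack lets the list of models be decided at runtime (unet only,
    text_encoder only, or both), which one static `with` statement cannot do.
    Contexts are exited in reverse order of entry.
    """
    with contextlib.ExitStack() as stack:
        if not sync:
            for model in models:
                stack.enter_context(accelerator.no_sync(model))
        yield


def train_epoch(accelerator, models, optimizer, train_dataloader,
                compute_loss, grad_accum_steps):
    """Manual gradient accumulation over whichever models are being trained.

    `compute_loss(batch)` is a hypothetical callable returning the loss.
    Assumes the Accelerator was NOT given its own gradient_accumulation_steps,
    so the loss is scaled manually below.
    """
    num_batches = len(train_dataloader)
    for step, batch in enumerate(train_dataloader):
        sync = should_sync(step, grad_accum_steps, num_batches)
        # backward() must run inside no_sync for the all-reduce to be skipped.
        with maybe_no_sync(accelerator, models, sync):
            loss = compute_loss(batch) / grad_accum_steps
            accelerator.backward(loss)
        if sync:
            optimizer.step()
            optimizer.zero_grad()
```

In the Dreambooth loop, models might be built from the existing flags, for example [m for m, on in ((unet, train_unet), (text_encoder, train_text_encoder)) if on], which covers the "unet is not training" case from the TODO. A learning-rate scheduler step, if used, would sit next to optimizer.step() in the sync branch.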