[training examples] reduce complexity by running final validations before export
I was thinking about Sayak's suggestion lately that the training examples are too long, and went through looking for redundant/unnecessary code sections that can be reduced or eliminated for readability.
The main thing that stands out is how the validations occur during the trainer unwind stage.
During training, we have access to the unet and other components, and we pass `is_final_validation=False` to the `log_validation` method, which behaves differently across the training examples. In the ControlNet example, it ends up reloading the ControlNet model from `args.output_dir` and constructing a pipeline around it:
```python
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
    logger.info("Running validation... ")

    if not is_final_validation:
        controlnet = accelerator.unwrap_model(controlnet)
        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            vae=vae,
            unet=unet,
            controlnet=controlnet,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
    else:
        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
        if args.pretrained_vae_model_name_or_path is not None:
            vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
        else:
            vae = AutoencoderKL.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
            )
        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            vae=vae,
            controlnet=controlnet,
            revision=args.revision,
            variant=args.variant,
            torch_dtype=weight_dtype,
        )
```
This reload only seems to happen because, at the end of training, the method is called after everything has been unloaded:
```python
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    controlnet = unwrap_model(controlnet)
    controlnet.save_pretrained(args.output_dir)

    # Run a final round of validation.
    # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
    image_logs = None
    if args.validation_prompt is not None:
        image_logs = log_validation(
            vae=None,
            unet=None,
            controlnet=None,
            args=args,
            accelerator=accelerator,
            weight_dtype=weight_dtype,
            step=global_step,
            is_final_validation=True,
        )
```
This has several downsides:

1. Theoretically it shows us the results of the final export, but in practice it's the same result as if we ran inference on the already-loaded weights without reloading them.
2. This particular case has the unet/vae/controlnet models loaded twice, as it would need to `del` them before loading the new ones.
3. When `max_train_steps = 1000` and `validation_steps = 100` (or any other value that goes evenly into `max_train_steps`), we run two validations back to back: one just before exiting the training loop, and then this one.
4. It's an unnecessary slowdown on systems that (for no reason I can discern) take a very long time to load pipelines.
If we just remove the final inference code, the earlier condition can be updated to run the validation before exiting the loop, which would solve issues 2-4:
From:

```python
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
    image_logs = log_validation(...)
```

To:

```python
if args.validation_prompt is not None and (global_step % args.validation_steps == 0 or global_step >= args.max_train_steps):
    image_logs = log_validation(...)
```
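To make the duplicate concrete, here's a small standalone sketch (not taken from the training script; the step simulation is hypothetical) counting how many times validation would fire under each condition when the check runs once per optimizer step:

```python
# Hypothetical sketch: count validation runs under the current vs. proposed condition.
max_train_steps = 1000
validation_steps = 100

# Current behaviour: in-loop validations, plus the extra is_final_validation=True
# call after the loop exits.
current = [step for step in range(1, max_train_steps + 1) if step % validation_steps == 0]
current.append(max_train_steps)  # final validation after saving

# Proposed behaviour: the in-loop condition also covers the last step, so no extra call.
proposed = [
    step
    for step in range(1, max_train_steps + 1)
    if step % validation_steps == 0 or step >= max_train_steps
]

print(len(current), len(proposed))  # 11 vs. 10: step 1000 is currently validated twice
```

When `validation_steps` divides `max_train_steps`, the current flow runs validation at the final step twice; the proposed condition collapses those into one.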