How to check a checkpoint's output quality?
If I use --checkpointing_steps for DreamBooth training, is there a way to check the output quality of those checkpoints? Can I automatically create sample images at each checkpoint? The description says "These checkpoints are only suitable for resuming training using --resume_from_checkpoint", so is there a way to convert the checkpoints into something usable for inference, such as a diffusers pipeline or a Stable Diffusion model? Or do I have to train until x steps, check the model's output, resume until x+y steps, check the output again, and so on?
Hi @djdookie! Thanks for raising this question; the explanation I wrote was misleading. It is indeed possible to convert a checkpoint into an inference pipeline with something like this:
from accelerate import Accelerator
from diffusers import DiffusionPipeline
# Load the pipeline with the same arguments (model, revision) that were used for training
model_id = "stabilityai/stable-diffusion-2-1"
pipeline = DiffusionPipeline.from_pretrained(model_id)
accelerator = Accelerator()
# Prepare text_encoder too only if `--train_text_encoder` was used; otherwise prepare just the unet
unet, text_encoder = accelerator.prepare(pipeline.unet, pipeline.text_encoder)
# Restore state from a checkpoint path
accelerator.load_state("/sddata/dreambooth/daruma-v2-1/checkpoint-100")
# Rebuild the pipeline with the unwrapped models (assignment to .unet and .text_encoder should work too)
pipeline = DiffusionPipeline.from_pretrained(
    model_id,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
)
# Perform inference, or save, or push to the hub
pipeline.save_pretrained("dreambooth-pipeline")
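To evaluate every checkpoint rather than a single one, you first need the checkpoint folders in training order. The sketch below assumes the folders are named `checkpoint-<step>` (as in the `checkpoint-100` path above) and sit directly inside the training output directory; `sorted_checkpoints` is a hypothetical helper, not part of diffusers or accelerate.

```python
import re
from pathlib import Path


def sorted_checkpoints(output_dir):
    """Return (step, path) pairs for every checkpoint-<step> folder, oldest first.

    Assumes the folder naming used in the snippet above (e.g. checkpoint-100).
    """
    pattern = re.compile(r"^checkpoint-(\d+)$")
    found = []
    for path in Path(output_dir).iterdir():
        match = pattern.match(path.name)
        # Skip files and folders that do not follow the checkpoint-<step> pattern
        if path.is_dir() and match:
            found.append((int(match.group(1)), path))
    # Sort numerically so checkpoint-1000 comes after checkpoint-200
    return sorted(found)
```

For each `(step, path)` pair you could run the conversion snippet above with `accelerator.load_state(str(path))` and save the result to a per-step folder (for example `f"dreambooth-pipeline-{step}"`), then generate the same prompts from each saved pipeline to compare quality across steps.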
I'll update the docstring and write a section in the documentation to clarify. Thanks!
Nice, thank you, @pcuenca!
Is there a way to do this during training, using the model that's already in memory rather than loading it up from a checkpoint?
@David-Hari good question. @patil-suraj I think we should add some code that allows for generating images during training. Also cc @pcuenca and @williamberman for DreamBooth.
I'll open a new issue for it.
We should definitely add some image logging in the script; adding this to my todo-list.
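Until image logging lands in the script, the idea of sampling from the in-memory model during training can be sketched as below. This is a minimal illustration under stated assumptions, not what the diffusers script does: `sample_during_training` and `should_sample` are hypothetical names, and `generate` is any callable you supply that turns a prompt and a seed into an image.

```python
from typing import Callable, List, Sequence


def should_sample(global_step: int, every_n_steps: int) -> bool:
    """True on steps N, 2N, 3N, ... (hypothetical helper)."""
    return every_n_steps > 0 and global_step > 0 and global_step % every_n_steps == 0


def sample_during_training(
    global_step: int,
    every_n_steps: int,
    generate: Callable[[str, int], object],
    prompts: Sequence[str],
    seed: int = 0,
) -> List[object]:
    """Generate one image per prompt every `every_n_steps` steps, else nothing.

    `generate` is an assumed callable wrapping a pipeline built from the
    models that are already in memory.
    """
    if not should_sample(global_step, every_n_steps):
        return []
    # Reuse the same seed at every sampling step so images are comparable over time
    return [generate(prompt, seed) for prompt in prompts]
```

Inside the training loop, `generate` could wrap a pipeline built from the unwrapped in-memory models, e.g. `DiffusionPipeline.from_pretrained(model_id, unet=accelerator.unwrap_model(unet), text_encoder=accelerator.unwrap_model(text_encoder))`, called as `pipeline(prompt, generator=torch.Generator(device).manual_seed(seed)).images[0]`. Fixing the seed keeps the initial noise identical across steps, so differences between sampled images reflect the training progress rather than randomness.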