Add LCM distillation training script for InstructPix2Pix SDXL
What does this PR do?
This PR adds a training script for Latent Consistency Model (LCM) distillation applied to InstructPix2Pix with Stable Diffusion XL. The distilled model enables fast, few-step (1-4 step) instruction-based image editing while preserving the output quality of the teacher.
Key Features
- LCM Distillation Pipeline: Implements teacher-student distillation in which a pre-trained InstructPix2Pix SDXL model (the teacher) guides training of a student model capable of single-step inference
- 8-Channel U-Net Support: Properly handles InstructPix2Pix's concatenated input (noisy latent + original-image latent); see the input-assembly sketch after this list
- Time Conditioning: Embeds the guidance scale $w$ and feeds it to the student U-Net through its time-conditioning projection for flexible inference
- EMA Target Network: Maintains an exponential-moving-average copy of the student as a stable source of training targets; see the EMA sketch after this list
- DDIM Solver Integration: Implements multi-step teacher predictions with classifier-free guidance (CFG)
- Flexible Loss Functions: Supports both L2 and Huber loss for robust training; see the loss sketch after this list
- Production-Ready: Includes validation, checkpointing, mixed precision, gradient checkpointing, and xFormers support
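A minimal sketch of the input assembly and $w$-embedding, assuming the student is a `UNet2DConditionModel` configured with `in_channels=8` and `time_cond_proj_dim=256` (as in the existing LCM distillation scripts); the tensor shapes, the [3, 15] guidance range, and the `prompt_embeds`/`added_cond_kwargs` names in the final comment are illustrative placeholders, not the script's exact API:

```python
import torch

def guidance_scale_embedding(w, embedding_dim=256, dtype=torch.float32):
    # Sinusoidal embedding of the guidance scale w, mirroring the
    # w-conditioning used in the existing LCM distillation scripts.
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# Hypothetical batch: noisy latents plus the VAE-encoded original image.
noisy_latents = torch.randn(2, 4, 128, 128)
image_latents = torch.randn(2, 4, 128, 128)
w = torch.rand(2) * (15.0 - 3.0) + 3.0   # w drawn from an assumed [3, 15] range

# InstructPix2Pix conditioning: channel-wise concat -> 8-channel U-Net input.
unet_input = torch.cat([noisy_latents, image_latents], dim=1)
w_embedding = guidance_scale_embedding(w, embedding_dim=256)

# With a student built as UNet2DConditionModel(in_channels=8, ...,
# time_cond_proj_dim=256), the forward call would look like:
#   pred = unet(unet_input, timesteps, encoder_hidden_states=prompt_embeds,
#               added_cond_kwargs=added_cond_kwargs, timestep_cond=w_embedding).sample
```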
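The EMA target update is conceptually small; this sketch assumes a plain parameter-wise update with a 0.95 decay (both the decay value and the `ema_update` helper name are assumptions, not the script's exact API):

```python
import copy
import torch

@torch.no_grad()
def ema_update(student: torch.nn.Module, target: torch.nn.Module, decay: float = 0.95):
    # Parameter-wise: target <- decay * target + (1 - decay) * student.
    for tgt, src in zip(target.parameters(), student.parameters()):
        tgt.mul_(decay).add_(src, alpha=1.0 - decay)

# The target network starts as a frozen copy of the student and is nudged
# toward it after every optimizer step.
student = torch.nn.Linear(4, 4)                  # stand-in for the student U-Net
target = copy.deepcopy(student).requires_grad_(False)
ema_update(student, target, decay=0.95)
```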
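For the two loss options, a sketch of what the choice plausibly looks like; the `distillation_loss` name and the `huber_c=0.001` default are assumptions, and the Huber branch uses the smooth pseudo-Huber form common in LCM training:

```python
import torch
import torch.nn.functional as F

def distillation_loss(model_pred, target, loss_type="l2", huber_c=0.001):
    # L2 is the standard choice; the pseudo-Huber variant is less sensitive
    # to outliers early in training, with huber_c setting the L2->L1 crossover.
    if loss_type == "l2":
        return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    if loss_type == "huber":
        return torch.mean(
            torch.sqrt((model_pred.float() - target.float()) ** 2 + huber_c**2) - huber_c
        )
    raise ValueError(f"Unknown loss_type: {loss_type}")
```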
Training Algorithm
1. Sample a timestep from the DDIM skip schedule
2. Add noise to the latents and sample a guidance scale $w \in [w_{\min}, w_{\max}]$
3. Student makes a single-step prediction from the noisy latents
4. Teacher performs a multi-step DDIM prediction with CFG
5. Target network (the EMA of the student) generates a stable training target
6. Compute the loss between the student and target predictions
7. Update the student parameters, then update the target network via EMA

A condensed sketch of one such step follows this list.
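This is a minimal, self-contained sketch of steps 1-7 under stated assumptions: the helper names (`pred_x0`, `ddim_step`, `boundary_scalings`), the skip interval of 20, the [3, 15] $w$-range, and the single-scale CFG are all illustrative, the `unet` stand-in just returns noise so the control flow runs, and the real script would use the scheduler's actual `alphas_cumprod` (note also that the real InstructPix2Pix teacher uses separate image and text guidance scales):

```python
import torch

# Toy stand-ins so the step runs end-to-end; in the script these are the
# student/teacher/target U-Nets and the noise scheduler's alphas_cumprod.
def unet(x, t, cond=None):
    return torch.randn_like(x[:, :4])            # epsilon prediction (4 channels)

alphas_cumprod = torch.linspace(0.999, 0.01, 1000)

def pred_x0(eps, x_t, t):
    # Epsilon parameterization: x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return (x_t - (1 - a).sqrt() * eps) / a.sqrt()

def ddim_step(x0, eps, t_prev):
    # Deterministic DDIM update landing on timestep t_prev.
    a_prev = alphas_cumprod[t_prev].view(-1, 1, 1, 1)
    return a_prev.sqrt() * x0 + (1 - a_prev).sqrt() * eps

def boundary_scalings(t, sigma_data=0.5, timestep_scaling=10.0):
    # Consistency boundary conditions: c_skip -> 1, c_out -> 0 as t -> 0.
    s = (t * timestep_scaling).view(-1, 1, 1, 1)
    return sigma_data**2 / (s**2 + sigma_data**2), s / (s**2 + sigma_data**2) ** 0.5

B = 2
latents = torch.randn(B, 4, 128, 128)            # VAE latents of the edited image
image_latents = torch.randn(B, 4, 128, 128)      # VAE latents of the source image

# 1-2) sample a DDIM timestep pair (skip interval assumed to be 20),
#      add noise, and draw w from an assumed [3, 15] range
t = torch.randint(100, 900, (B,))
t_prev = t - 20
a_t = alphas_cumprod[t].view(-1, 1, 1, 1)
noise = torch.randn_like(latents)
noisy = a_t.sqrt() * latents + (1 - a_t).sqrt() * noise
w = (torch.rand(B) * 12.0 + 3.0).view(-1, 1, 1, 1)

# 3) student single-step consistency prediction at t
x_in = torch.cat([noisy, image_latents], dim=1)  # 8-channel input
c_skip, c_out = boundary_scalings(t)
student_pred = c_skip * noisy + c_out * pred_x0(unet(x_in, t), noisy, t)

# 4) teacher CFG prediction, then one DDIM step from t to t_prev
with torch.no_grad():
    eps_cond, eps_uncond = unet(x_in, t, "cond"), unet(x_in, t, None)
    eps_cfg = eps_uncond + w * (eps_cond - eps_uncond)
    x_prev = ddim_step(pred_x0(eps_cfg, noisy, t), eps_cfg, t_prev)

    # 5) EMA target network evaluated at the teacher's x_prev
    c_skip_p, c_out_p = boundary_scalings(t_prev)
    x_in_prev = torch.cat([x_prev, image_latents], dim=1)
    target_pred = c_skip_p * x_prev + c_out_p * pred_x0(
        unet(x_in_prev, t_prev), x_prev, t_prev
    )

# 6-7) loss (see distillation_loss above), backward, optimizer step, then
#      ema_update(student, target) from the EMA sketch
loss = torch.nn.functional.mse_loss(student_pred.float(), target_pred.float())
```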
Use Case
This script lets researchers and practitioners build fast InstructPix2Pix SDXL models that perform high-quality image editing in as few as 4 inference steps instead of the 50+ a standard sampler needs, making real-time image editing applications feasible.
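Assuming the distilled checkpoint loads into the existing `StableDiffusionXLInstructPix2PixPipeline` with an `LCMScheduler` swapped in (as with other LCM-distilled models), inference might look like the sketch below; the model ID and image URL are placeholders:

```python
import torch
from diffusers import StableDiffusionXLInstructPix2PixPipeline, LCMScheduler
from diffusers.utils import load_image

# "your-org/lcm-instruct-pix2pix-sdxl" is a hypothetical checkpoint produced
# by this training script; point it at your own output directory instead.
pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
    "your-org/lcm-instruct-pix2pix-sdxl", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

image = load_image("https://example.com/input.png")   # placeholder input image
edited = pipe(
    "make it look like winter",                       # edit instruction
    image=image,
    num_inference_steps=4,   # few-step inference enabled by the distillation
    guidance_scale=1.0,      # CFG is baked in via the w-embedding, so no extra CFG pass
).images[0]
```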
Who can review?
@yiyixuxu