diffusers
[Community Pipeline] Checkpoint merging
Intro
Community Pipelines were introduced in diffusers==0.4.0 with the idea of allowing the community to quickly add, integrate, and share their custom pipelines on top of diffusers.
You can find a guide about Community Pipelines here. You can also find all the community examples under examples/community/. If you have questions about the Community Pipelines feature, please head to the parent issue.
Idea: Checkpoint Merging
This pipeline aims to merge two checkpoints into one via interpolation.
@apolinario, any resources you would suggest for this pipeline?
Yes. Checkpoint merging is implemented in the AUTOMATIC1111/stable-diffusion-webui repo: it interpolates the weights of models (which can be fine-tuned models or different versions of a model), potentially creating cool results.
This is how it is implemented in the original repo.
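The exact formulas are not spelled out in this thread. As a minimal sketch, assuming a "weighted sum" merge is plain linear interpolation and that a "sigmoid" option means easing alpha with the smoothstep curve before interpolating, the per-tensor functions could look like this. The function names are illustrative only, and the easing curves are an assumption, not the webui's code. They work on Python floats, NumPy arrays, or torch tensors alike.

```python
import math


def weighted_sum(theta0, theta1, alpha):
    # Linear interpolation: alpha=0 keeps theta0, alpha=1 gives theta1.
    return (1 - alpha) * theta0 + alpha * theta1


def sigmoid(theta0, theta1, alpha):
    # Assumed "sigmoid": ease alpha with smoothstep (3a^2 - 2a^3), then interpolate.
    alpha = alpha * alpha * (3 - 2 * alpha)
    return theta0 + (theta1 - theta0) * alpha


def inv_sigmoid(theta0, theta1, alpha):
    # Assumed "inverse sigmoid": the inverse of smoothstep on [0, 1], then interpolate.
    alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
    return theta0 + (theta1 - theta0) * alpha
```

All three agree at alpha = 0, 0.5 and 1; they differ only in how quickly the result moves away from each endpoint.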
Would the following be a good start?
# STEP 1:
# Verify that the checkpoints have the same dimensions/modules etc.
# STEP 2:
# Find the mergeable modules from both checkpoints (vae, unet, safety checker, etc.)
# STEP 3:
# For each mergeable component, use the interpolation technique on the component weights and update the weights
# STEP 4:
# Return a pipeline with the merged weights.
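The four steps above can be sketched at the state-dict level. This assumes each checkpoint is represented as a dict of components (e.g. "unet", "vae") mapping to state dicts; `merge_checkpoints` is a hypothetical helper, NumPy arrays stand in for torch tensors (the same arithmetic and `.shape` work on both), and incompatible entries are kept from the first checkpoint rather than raising an error.

```python
import numpy as np


def merge_checkpoints(ckpt0, ckpt1, alpha=0.5):
    """Merge two checkpoints given as {component_name: state_dict}.

    Hypothetical helper: returns (merged, skipped), where `skipped` lists
    components or parameters that could not be interpolated.
    """
    merged, skipped = {}, []
    for name, sd0 in ckpt0.items():
        sd1 = ckpt1.get(name)
        # STEP 1 + 2: a component is only mergeable if both checkpoints have it.
        if sd1 is None:
            merged[name] = dict(sd0)
            skipped.append(name)
            continue
        out = {}
        for key, theta0 in sd0.items():
            theta1 = sd1.get(key)
            # Skip parameters that are missing or whose shapes differ.
            if theta1 is None or theta1.shape != theta0.shape:
                out[key] = theta0
                skipped.append(f"{name}.{key}")
                continue
            # STEP 3: weighted-sum interpolation of the matching weights.
            out[key] = (1 - alpha) * theta0 + alpha * theta1
        merged[name] = out
    # STEP 4: the caller would load `merged` back into a pipeline.
    return merged, skipped
```

Returning the list of skipped entries keeps the merge permissive while still making it visible which parts of the checkpoints were not actually merged.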
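So you would run it like this:
# Proposed usage; CheckpointMergerPipeline and its arguments are only a sketch at this point
pipe = CheckpointMergerPipeline.from_pretrained(chkpt0 = "sample/checkpoint-1", chkpt1 = "sample/checkpoint-2", alpha = 0.2, interp = "sigmoid")
# save_pretrained expects an output directory ("merged-checkpoint" is a placeholder)
pipe.save_pretrained("merged-checkpoint")
prompt = "A cat riding a skateboard in an 18th century street at night, moon in the background"
pipe.to("cuda")
image = pipe(prompt).images[0]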
What do you think? @apolinario @osanseviero
That seems sensible to me
Hey @Abhinay1997, are you still working on this? If not, @patrickvonplaten may I please give this a shot? :slightly_smiling_face:
Hi @vvvm23, I'm currently working on a solution. Give me 24 hours to make a PR. I'll let you know if my approach fails, and maybe you can pick it up?
No worries! Please take your time!
@patrickvonplaten, @apolinario
I ran into the following issue and wanted your thoughts:
When building a custom pipeline that inherits from DiffusionPipeline, I have to do checkpoint merging this way:
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline = "checkpoint_merger_pipeline")
## Merge v1.4 weights with v1.2 weights of stable diffusion.
pipe.merge("CompVis/stable-diffusion-v1-2")
The problem here is that the custom pipeline is loaded via the from_pretrained method of DiffusionPipeline. To ensure that any kind of checkpoint is mergeable, we need to keep the modules dynamic. However, from_pretrained expects the custom_pipeline class to declare its kwargs in advance, and otherwise returns an empty pipeline.
One solution is to pass the original checkpoint to the merge method again, like this:
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline = "checkpoint_merger_pipeline")
## Merge v1.4 weights with v1.2 weights of stable diffusion.
## SEE EXTRA ARGS HERE
pipe.merge("CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-2")
But that would make the from_pretrained call redundant, apart from returning an instance of the custom_pipeline class, and could cause confusion.
What would you suggest ?
Really up to you as it's a community pipeline :-) In general we cannot assume to be able to merge weights at a low level, so I'm also not 100% sure of the usefulness here - @apolinario do you have a good idea?
Just my random thoughts, but why make this a pipeline instead of something functional? For example, a function with args module, ckpt1, ckpt2, alpha?
Update: validation of the checkpoints and the individual interpolation steps are now complete. The only thing left is returning a DiffusionPipeline object with the updated weights.
Should be done by the weekend.
Hey team, sorry it took a while, but here's the Colab notebook with my code. I need your help figuring out how to reduce memory usage: it currently crashes for the scenario in the notebook (12 GB RAM). I'll try it on Kaggle (16 GB RAM) and see if it's any better.
Success! It runs on Kaggle and uses around 13 GB of RAM to merge Stable Diffusion and Waifu Diffusion. I'll make a PR over the weekend. I have made the notebook public if anyone's interested.
https://www.kaggle.com/abhinaydevarinti/checkpoint-merging-huggingface-diffusers
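The notebook's code is not reproduced here. One common way to lower peak RAM, sketched under the assumption that both state dicts hold floating-point arrays or tensors, is to merge one tensor at a time, write the result into the first state dict in place, and pop each entry from the second as soon as it has been used. `merge_in_place` is a hypothetical helper, NumPy stands in for torch (`*=` and `+=` are also in place on torch tensors), and memory is only actually released if nothing else still references the popped arrays.

```python
import numpy as np


def merge_in_place(sd0, sd1, alpha=0.5):
    """Interpolate sd1 into sd0 tensor by tensor to limit peak memory.

    Sketch only: sd0's arrays are overwritten, and sd1 is emptied as it is
    consumed so each of its tensors can be freed right after use.
    """
    for key in list(sd1.keys()):
        theta1 = sd1.pop(key)
        theta0 = sd0.get(key)
        if theta0 is None or theta0.shape != theta1.shape:
            continue  # incompatible entry: keep sd0's weight unchanged
        # In-place ops write into theta0's buffer instead of allocating a new
        # full-size result; only one tensor-sized temporary is created.
        theta0 *= 1 - alpha
        theta0 += alpha * theta1
    return sd0
```

The trade-off is that the first checkpoint's state dict is destroyed, so it has to be reloaded if you want to try a different alpha.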
Thanks @Abhinay1997
Could more "advanced" merging be supported by similar pipelines?
Namely this, which produces better results when merging models with different "base" SD ancestors: https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
And this, which allows for specifying different weights for different blocks (with the benefits specified in the linked blog post): https://github.com/bbc-mc/sdweb-merge-block-weighted-gui
Gently pinging @Abhinay1997 here :-)
@brucethemoose, I briefly went through the post you mentioned. From what I gather, this method permutes the weights to find an optimal match before merging. From one of the repos linked above, I saw that the matching differs depending on the network architecture. However, I might be wrong. Let me look into this in more detail.
For now, this pipeline is a simple interpolation: it blindly merges matched state_dicts as long as they are compatible and skips incompatible ones. Frankly, I think it should be a standalone pipeline of its own. However, let me go over this in detail and get back to you. :)
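To make the permutation idea concrete, here is a toy illustration, not the linked repo's algorithm: for a 2-layer MLP with an elementwise activation (y = W2 act(W1 x)), reordering hidden units does not change what the network computes, so network B's hidden units can be permuted to best match network A's before averaging. The helper names are hypothetical, biases are omitted, and the brute-force search only works for tiny layers; real implementations treat the matching as a linear assignment problem.

```python
import itertools

import numpy as np


def match_hidden_units(w1_a, w1_b):
    """Return the permutation of B's hidden units that best matches A's."""
    n = w1_a.shape[0]
    best_perm, best_score = None, -np.inf
    for perm in itertools.permutations(range(n)):
        # Similarity = sum of dot products between matched rows of W1.
        score = np.sum(w1_a * w1_b[list(perm)])
        if score > best_score:
            best_perm, best_score = list(perm), score
    return best_perm


def permute_and_merge(a, b, alpha=0.5):
    """Merge two 2-layer MLPs {"W1", "W2"} after aligning B's hidden units to A."""
    perm = match_hidden_units(a["W1"], b["W1"])
    # Permuting W1's rows and W2's columns together leaves B's function unchanged.
    b_aligned = {"W1": b["W1"][perm], "W2": b["W2"][:, perm]}
    return {k: (1 - alpha) * a[k] + alpha * b_aligned[k] for k in a}
```

If B is just A with shuffled hidden units, a naive weighted sum blurs unrelated units together, while the aligned merge recovers A exactly.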
@Abhinay1997 @brucethemoose I've implemented block-weighted merging in grate, based on a modified version of checkpoint_merger. Check it out here: https://github.com/damian0815/grate/blob/main/src/sdgrate/checkpoint_merger_mbw.py (or just pip install sdgrate and then run grate). Maybe my changes could be folded into yours, @Abhinay1997?
Damian,
Your implementation of the block-weighted merging is super cool! My motivation with CheckpointMergerPipeline was the hope that it would work as a general-purpose merger for all modules of the passed checkpoint. In line with this, I think we can make the module check (UNet2DConditionModel) dynamic (another argument to the merge method) and make block_weights a nested dict instead.
Of course, these are just my thoughts. I'd like to know what you think.
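One way the nested block_weights idea could work is sketched below. This is a hypothetical format, not grate's actual API: block_weights maps a component name to {parameter-name prefix: alpha}, any component can be listed (so the module check is no longer hard-coded to the UNet), and each parameter uses the alpha of its longest matching prefix, falling back to a default. NumPy arrays stand in for torch tensors.

```python
import numpy as np


def block_weighted_merge(ckpt0, ckpt1, block_weights, default_alpha=0.5):
    """Merge {component: state_dict} checkpoints with per-block alphas.

    `block_weights` is a hypothetical nested dict such as
    {"unet": {"down_blocks.0": 0.2, "mid_block": 0.8}}.
    """
    merged = {}
    for component, sd0 in ckpt0.items():
        sd1 = ckpt1.get(component, {})
        weights = block_weights.get(component, {})
        out = {}
        for key, theta0 in sd0.items():
            theta1 = sd1.get(key)
            if theta1 is None or theta1.shape != theta0.shape:
                out[key] = theta0  # incompatible: keep ckpt0's weight
                continue
            # Match on whole dotted segments so "down_blocks.1" does not
            # also match "down_blocks.10...".
            prefixes = [p for p in weights if key == p or key.startswith(p + ".")]
            alpha = weights[max(prefixes, key=len)] if prefixes else default_alpha
            out[key] = (1 - alpha) * theta0 + alpha * theta1
        merged[component] = out
    return merged
```

Choosing the longest matching prefix lets a coarse entry like "down_blocks" set a default for a whole stage while a more specific entry overrides it for one block.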
@Abhinay1997 i made a PR https://github.com/huggingface/diffusers/pull/2422