
[Community Pipeline] Checkpoint merging

Open osanseviero opened this issue 1 year ago • 20 comments

Intro

Community Pipelines were introduced in diffusers==0.4.0 to allow the community to quickly add, integrate, and share their custom pipelines on top of diffusers.

You can find a guide about Community Pipelines here. You can also find all the community examples under examples/community/. If you have questions about the Community Pipelines feature, please head to the parent issue.

Idea: Checkpoint Merging

This pipeline aims to merge two checkpoints into one via interpolation.

osanseviero avatar Oct 17 '22 14:10 osanseviero

@apolinario any resources you would suggest for this pipeline?

osanseviero avatar Oct 17 '22 14:10 osanseviero

Yes, so checkpoint merging is an idea implemented in the AUTOMATIC1111/stable-diffusion-webui repo. It interpolates the weights of two models (fine-tuned models or different versions of a model), potentially creating cool results.
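At its core, the webui's "weighted sum" mode is just a linear interpolation of matching weight tensors. A minimal sketch (paraphrased, not the repo's actual code):

```python
import torch

# Weighted-sum merge of two matching weight tensors:
# alpha = 0 keeps model A's weights, alpha = 1 gives model B's.
def weighted_sum(theta_a: torch.Tensor, theta_b: torch.Tensor, alpha: float) -> torch.Tensor:
    return (1.0 - alpha) * theta_a + alpha * theta_b
```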

This is how it is implemented in the original repo

apolinario avatar Oct 18 '22 09:10 apolinario

Would the following be a good start?

# STEP 1:
# Verify that the checkpoints have the same dimensions/modules etc.
# STEP 2:
# Find the mergeable modules from both the checkpoints. ( vae, unet, safety checker etc.. )
# STEP 3:
# For each mergeable component,  use the interpolation technique on the component weights and update the weights
# STEP 4:
# Return a pipeline with the merged weights.
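Steps 1–3 above could be sketched roughly like this over plain state dicts (the function names and the sigmoid remapping of alpha are illustrative assumptions, not diffusers API):

```python
import torch

def interpolate(theta0, theta1, alpha, interp="linear"):
    # STEP 3: blend two tensors; "sigmoid" is one possible remapping of alpha
    if interp == "sigmoid":
        alpha = torch.sigmoid(torch.tensor((alpha - 0.5) * 12.0)).item()
    return (1 - alpha) * theta0 + alpha * theta1

def merge_state_dicts(sd0, sd1, alpha=0.2, interp="linear"):
    # STEP 1/2: only parameters present in both dicts with matching
    # shapes are mergeable; everything else keeps the base weights.
    merged = {}
    for key, theta0 in sd0.items():
        theta1 = sd1.get(key)
        if theta1 is not None and theta0.shape == theta1.shape:
            merged[key] = interpolate(theta0, theta1, alpha, interp)
        else:
            merged[key] = theta0
    return merged
```

Step 4 would then load the merged dicts back into the pipeline's modules via load_state_dict.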

So you would run it like this:

pipe = CheckpointMergerPipeline.from_pretrained(
    chkpt0="sample/checkpoint-1",
    chkpt1="sample/checkpoint-2",
    alpha=0.2,
    interp="sigmoid",
)
pipe.save_pretrained("sample/merged-checkpoint")

prompt = "A cat riding a skateboard in an 18th century street at night, moon in the background"

pipe.to("cuda")
pipe(prompt).images[0]

What do you think? @apolinario @osanseviero

Abhinay1997 avatar Oct 18 '22 15:10 Abhinay1997

That seems sensible to me

patrickvonplaten avatar Oct 20 '22 16:10 patrickvonplaten

Hey @Abhinay1997, are you still working on this? If not, @patrickvonplaten, may I please give this a shot? 🙂

vvvm23 avatar Oct 25 '22 05:10 vvvm23

Hi @vvvm23, I'm currently working on a solution. Give me 24 hrs to make a PR. I'll let you know if my approach fails, and maybe you can pick it up?

Abhinay1997 avatar Oct 25 '22 05:10 Abhinay1997

No worries! Please take your time!

vvvm23 avatar Oct 25 '22 05:10 vvvm23

@patrickvonplaten, @apolinario

I ran into the following issue and wanted your thoughts:-

When building a custom pipeline inheriting from DiffusionPipeline, for checkpoint merging I have to do it this way:-

pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline = "checkpoint_merger_pipeline")
## Merge v1.4 weights with v1.2 weights of stable diffusion.
pipe.merge("CompVis/stable-diffusion-v1-2")

The problem here is that the custom_pipeline is loaded via the from_pretrained method of DiffusionPipeline. To ensure that any kind of checkpoint is mergeable, we need to keep the modules dynamic. However, the from_pretrained method expects the custom_pipeline class to declare its kwargs in advance, and otherwise returns an empty pipeline.

One solution is to pass the original checkpoint again in the merge method like:-

pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline = "checkpoint_merger_pipeline")
## Merge v1.4 weights with v1.2 weights of stable diffusion.
## SEE EXTRA ARGS HERE
pipe.merge("CompVis/stable-diffusion-v1-4", "CompVis/stable-diffusion-v1-2")

But then the from_pretrained call would be redundant, serving only to return an instance of the custom_pipeline class, which could cause confusion.

What would you suggest?

Abhinay1997 avatar Oct 26 '22 15:10 Abhinay1997

Really up to you as it's a community pipeline :-) In general we cannot assume to be able to merge weights at a low level, so I'm also not 100% sure of the usefulness here - @apolinario do you have a good idea?

patrickvonplaten avatar Oct 26 '22 16:10 patrickvonplaten

Just my random thoughts, but why make this a pipeline instead of something functional? For example, a function with args module, ckpt1, ckpt2, alpha?
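A functional version along those lines might look like this (merge_module is a hypothetical helper, not an existing diffusers function):

```python
import torch

def merge_module(module_a: torch.nn.Module, module_b: torch.nn.Module, alpha: float) -> torch.nn.Module:
    # Linearly interpolate every parameter that exists in both modules
    # with a matching shape, then load the result back into module_a.
    sd_a, sd_b = module_a.state_dict(), module_b.state_dict()
    merged = {
        k: (1 - alpha) * v + alpha * sd_b[k]
        for k, v in sd_a.items()
        if k in sd_b and v.shape == sd_b[k].shape
    }
    sd_a.update(merged)
    module_a.load_state_dict(sd_a)
    return module_a
```

Calling it per component (unet, vae, text_encoder, ...) would sidestep the pipeline-construction question entirely.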

vvvm23 avatar Oct 27 '22 07:10 vvvm23

Update: I've reached a point in the solution where validation of the checkpoints and the individual interpolation steps are complete. The only thing left is returning a DiffusionPipeline object with the merged weights.

Should be done by the weekend.

Abhinay1997 avatar Nov 03 '22 08:11 Abhinay1997

Hey team, sorry it took a while. Here's the Colab notebook with my code. I need your help figuring out how to cut down on memory usage: currently it crashes for the scenario in the notebook (12 GB RAM). I'll try it out on Kaggle (16 GB RAM) and see if it's any better.

Abhinay1997 avatar Nov 20 '22 16:11 Abhinay1997

Success! It runs on Kaggle, using around 13 GB of RAM to merge Stable Diffusion and Waifu Diffusion. I'll make a PR over the weekend. I've made the notebook public if anyone's interested.

https://www.kaggle.com/abhinaydevarinti/checkpoint-merging-huggingface-diffusers

Abhinay1997 avatar Nov 26 '22 04:11 Abhinay1997

Thanks @Abhinay1997

patrickvonplaten avatar Nov 30 '22 12:11 patrickvonplaten

Could more "advanced" merging be supported by similar pipelines?

Namely this, which produces better results when merging models with different "base" SD ancestors: https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion

And this, which allows for specifying different weights for different blocks (with the benefits specified in the linked blog post): https://github.com/bbc-mc/sdweb-merge-block-weighted-gui
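As a rough illustration of the block-weighted idea, one could key per-block mix ratios by parameter-name prefix (this prefix matching is a simplification, not the linked GUI's actual scheme):

```python
import torch

def block_weighted_merge(sd_a, sd_b, block_alphas, default_alpha=0.5):
    # block_alphas maps a parameter-name prefix (e.g. "down_blocks.0")
    # to its own mix ratio; unmatched keys fall back to default_alpha.
    merged = {}
    for key, theta_a in sd_a.items():
        alpha = default_alpha
        for prefix, a in block_alphas.items():
            if key.startswith(prefix):
                alpha = a
                break
        merged[key] = (1 - alpha) * theta_a + alpha * sd_b[key]
    return merged
```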

brucethemoose avatar Jan 05 '23 19:01 brucethemoose

Gently pinging @Abhinay1997 here :-)

patrickvonplaten avatar Jan 10 '23 13:01 patrickvonplaten

@brucethemoose, I briefly went through the post that you mentioned. From what I gather, this method permutes the weights to find an optimal match before merging. From one of the repos linked above, I saw that the matching differs depending on the network architecture. However, I might be wrong. Let me look into this in more detail.

As for now, this pipeline is a simple interpolation: it blindly merges matched state_dicts as long as they are compatible and skips over incompatible ones. Frankly, I think it should be a standalone pipeline on its own. However, let me go over this in detail and get back to you. :)

Abhinay1997 avatar Jan 11 '23 17:01 Abhinay1997

@Abhinay1997 @brucethemoose I've implemented block-weighted merging in grate, based on a modified version of checkpoint_merger. Check it out here: https://github.com/damian0815/grate/blob/main/src/sdgrate/checkpoint_merger_mbw.py (or just pip install sdgrate and then run grate). Maybe my changes could be folded into yours, @Abhinay1997?

damian0815 avatar Feb 19 '23 13:02 damian0815

Damian,

Your implementation of block-weighted merging is super cool! My motivation with CheckpointMergerPipeline was the hope that it would work as a general-purpose merger for all modules of the passed checkpoint. In line with this, I think we can make the module check (UNet2DConditionModel) dynamic (another argument to the merge method) and make block_weights a nested dict instead.

Of course, these are just my thoughts. I'd like to know what you think.

Abhinay1997 avatar Feb 19 '23 14:02 Abhinay1997

@Abhinay1997 I made a PR: https://github.com/huggingface/diffusers/pull/2422

damian0815 avatar Feb 19 '23 14:02 damian0815