diffusers
Support LinFusion
What does this PR do?
Support LinFusion. LinFusion accelerates diffusion models by replacing every self-attention layer in a diffusion UNet with a distilled Generalized Linear Attention layer. The distilled model has linear complexity in the number of image tokens and remains highly compatible with existing diffusion plugins such as ControlNet, IP-Adapter, and LoRA. Standard self-attention scales quadratically with the number of tokens, so the speedup grows with resolution and can be dramatic for high-resolution generation. Dedicated pipelines for high-resolution generation are available in the original LinFusion codebase.
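For intuition on where linear complexity comes from, here is a minimal NumPy sketch that contrasts standard softmax self-attention with a generic kernelized linear attention. It is an illustration under stated assumptions, not LinFusion's Generalized Linear Attention: the `elu(x) + 1` feature map and all function names are illustrative choices. The key idea is associativity. Computing φ(K)ᵀV first yields a d × d summary, so the cost is O(N·d²) rather than the O(N²·d) needed to materialize an N × N score matrix.

```python
import numpy as np


def softmax_attention(q, k, v):
    """Standard self-attention: builds an (N, N) score matrix, so cost is O(N^2 * d)."""
    scores = q @ k.T / np.sqrt(q.shape[-1])          # (N, N)
    scores -= scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v                               # (N, d)


def feature_map(x):
    """Positive feature map elu(x) + 1 (illustrative assumption, not LinFusion's)."""
    return np.where(x > 0, x + 1.0, np.exp(x))


def linear_attention(q, k, v):
    """Kernelized linear attention: computes phi(K)^T V first, so cost is O(N * d^2)."""
    phi_q, phi_k = feature_map(q), feature_map(k)    # (N, d) each
    kv = phi_k.T @ v                                 # (d, d): size independent of N
    normalizer = phi_q @ phi_k.sum(axis=0)           # (N,)
    return (phi_q @ kv) / normalizer[:, None]        # (N, d)


def linear_attention_naive(q, k, v):
    """The same linear attention evaluated in the quadratic (N, N) order, for comparison."""
    phi_q, phi_k = feature_map(q), feature_map(k)
    sim = phi_q @ phi_k.T                            # (N, N)
    return (sim @ v) / sim.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
n_tokens, dim = 1024, 64  # e.g. a 32x32 latent flattened into tokens
q, k, v = (rng.standard_normal((n_tokens, dim)) for _ in range(3))

out_softmax = softmax_attention(q, k, v)
out_linear = linear_attention(q, k, v)
out_naive = linear_attention_naive(q, k, v)
print(out_linear.shape, np.allclose(out_linear, out_naive))  # (1024, 64) True
```

Both linear-attention functions return the same result; only the order of multiplication differs. Doubling the image side length quadruples the token count N, so the softmax score matrix grows 16×, while the `kv` summary in the linear path keeps its d × d size.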
With this PR, enabling LinFusion takes only one additional line:
```python
import torch
from diffusers import StableDiffusionPipeline

repo_id = "stabilityai/stable-diffusion-2-1"
pipe = StableDiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16, variant="fp16").to("cuda")
pipe.load_linfusion(pipeline_name_or_path=repo_id)  # the one additional line
image = pipe("a photo of an astronaut on a moon").images[0]
```
Currently, `stable-diffusion-v1-5/stable-diffusion-v1-5`, `stabilityai/stable-diffusion-2-1`, `stabilityai/stable-diffusion-xl-base-1.0`, models fine-tuned from them, and pipelines based on them are supported. If the `repo_id` differs from these, e.g., when using a fine-tuned model from the community, you need to explicitly set `pipeline_name_or_path` to the model it is based on. Otherwise, this argument is optional and LinFusion will read it from the current pipeline. Alternatively, you can specify the argument `pretrained_model_name_or_path_or_dict` to load LinFusion from other sources. You can also unload it with `pipe.unload_linfusion()` when it is no longer needed.
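The rule for when `pipeline_name_or_path` must be passed can be sketched in plain Python. This is a hypothetical helper for illustration only: `SUPPORTED_BASES` and `resolve_base_model` are not part of diffusers or LinFusion, the error it raises is an assumption, and the `pretrained_model_name_or_path_or_dict` path is not modeled.

```python
from typing import Optional

# Hypothetical: the base models this PR lists as supported.
SUPPORTED_BASES = {
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "stabilityai/stable-diffusion-2-1",
    "stabilityai/stable-diffusion-xl-base-1.0",
}


def resolve_base_model(current_repo_id: str, pipeline_name_or_path: Optional[str] = None) -> str:
    """Hypothetical helper: pick the base model whose LinFusion weights should be loaded.

    An explicit pipeline_name_or_path takes precedence; otherwise fall back to the
    repo id the pipeline was loaded from, which only works for a supported base.
    """
    base = pipeline_name_or_path or current_repo_id
    if base not in SUPPORTED_BASES:
        raise ValueError(
            f"{base!r} is not a supported base model; pass pipeline_name_or_path "
            "explicitly, set to the model your fine-tune is based on."
        )
    return base


# Loaded from an official repo: the argument can be omitted.
print(resolve_base_model("stabilityai/stable-diffusion-2-1"))
# Loaded from a hypothetical community fine-tune of SD 2.1: name the base explicitly.
print(resolve_base_model("some-user/my-sd21-finetune",
                         pipeline_name_or_path="stabilityai/stable-diffusion-2-1"))
```

The repo id `some-user/my-sd21-finetune` is a placeholder; with the argument omitted, this sketch would raise `ValueError` for it.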
We have also added documentation with a worked example at `docs/source/en/optimization/linfusion.md`.
Thanks in advance for reviewing this pull request! We are open to any changes that help LinFusion fit best into the current diffusers library!
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [ ] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.