Add support for MagCache
What does this PR do?
This PR adds support for MagCache (Magnitude-aware Cache), a training-free inference acceleration method for diffusion models, specifically targeting Transformer-based architectures like Flux.
This implementation follows the ModelHook pattern (similar to FirstBlockCache) to integrate seamlessly into Diffusers.
Key features:
- MagCacheConfig: Configuration class to control the threshold, retention ratio, and skipping limits.
- Calibration Mode: Adds a calibrate=True flag. When enabled, the hook runs full inference and calculates/prints the magnitude ratios for the specific model and scheduler. This makes MagCache compatible with any transformer model (e.g., Hunyuan, Wan, SD3), not just Flux.
- Strict Validation: To ensure correctness across different models, mag_ratios must be explicitly provided in the config (or calibration enabled).
- Flux Support: Includes pre-computed FLUX_MAG_RATIOS as a constant for convenience, derived from the official implementation.
- Mechanism: The hook calculates the accumulated error of the residual magnitude at each step. If the error is below the defined threshold, it skips the computation of the transformer blocks and approximates the output using the residual from the previous step.
Fixes #12697
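As a rough illustration of the mechanism bullet, the sketch below simulates which steps would be skipped for a given list of mag_ratios. It assumes the accumulation rule from the original MagCache reference code: multiply the ratios since the last full computation, sum the deviations |1 - accumulated ratio| as the error, and reset everything after each full computation. plan_skips is a hypothetical helper, not the hook's actual code, and it does not touch any model.

```python
from typing import List


def plan_skips(
    mag_ratios: List[float],
    threshold: float = 0.06,
    max_skip_steps: int = 3,
    retention_ratio: float = 0.2,
) -> List[bool]:
    """Return one flag per step: True means the transformer blocks would be skipped
    and the cached residual from the last computed step reused."""
    num_steps = len(mag_ratios)
    # The first retention_ratio fraction of steps is always computed; step 0 is
    # also always computed so a residual exists to reuse (assumption).
    first_skippable = max(int(retention_ratio * num_steps), 1)

    acc_ratio, acc_steps, acc_err = 1.0, 0, 0.0
    plan = []
    for step, ratio in enumerate(mag_ratios):
        skip = False
        if step >= first_skippable:
            acc_ratio *= ratio  # magnitude drift since the last full computation
            acc_steps += 1  # consecutive steps that would reuse the cache
            acc_err += abs(1.0 - acc_ratio)
            if acc_err < threshold and acc_steps <= max_skip_steps:
                skip = True
            else:
                # Full computation: refresh the cached residual and reset accumulators.
                acc_ratio, acc_steps, acc_err = 1.0, 0, 0.0
        plan.append(skip)
    return plan


print(plan_skips([1.0] * 10))
# [False, False, True, True, True, False, True, True, True, False]
```

With perfectly flat ratios the error never grows, so max_skip_steps alone forces a full computation every fourth step after the retention window.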
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case. (#12697)
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
@sayakpaul
@leffff could you review as well if possible?
Hi @AlanPonnachan @sayakpaul, the thing with MagCache is that it requires precomputing magnitudes. @AlanPonnachan has done it for Flux, but how will this work for other models?
@AlanPonnachan ?
@leffff , Thank you for your review.
To address this, I am implementing a Calibration Mode.
My plan is to add a calibrate=True flag to MagCacheConfig. When enabled:
- The pipeline runs full inference (no skipping).
- The hook calculates the residual magnitude ratios at every step.
- At the end of inference, it logs/prints the resulting array of ratios.
Users can then simply run one calibration pass for their specific model/scheduler, copy the output ratios, and pass them into MagCacheConfig(mag_ratios=[...]) for optimized inference. This makes the implementation completely model-agnostic.
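To make the calibration pass concrete, here is a minimal sketch of how a ratio array like the one it prints could be derived from per-step residuals (the transformer blocks' output minus their input). It assumes the definition used in the original MagCache reference code: the mean over tokens of the per-token L2-norm ratio between consecutive steps, with 1.0 for the first step. calibrate_mag_ratios is a hypothetical helper on plain Python lists, not the code in this PR.

```python
import math
from typing import List

Residual = List[List[float]]  # one residual per step: a list of token vectors


def token_norms(residual: Residual) -> List[float]:
    # L2 norm of each token's residual vector
    return [math.sqrt(sum(x * x for x in token)) for token in residual]


def calibrate_mag_ratios(residuals: List[Residual]) -> List[float]:
    ratios = [1.0]  # nothing to compare against at the first step
    for prev, cur in zip(residuals, residuals[1:]):
        per_token = [c / p for c, p in zip(token_norms(cur), token_norms(prev))]
        ratios.append(sum(per_token) / len(per_token))
    return ratios


# Toy example: a single token whose residual norm goes 5 -> 10 -> 5
print(calibrate_mag_ratios([[[3.0, 4.0]], [[6.0, 8.0]], [[3.0, 4.0]]]))  # [1.0, 2.0, 0.5]
```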
I am working on this update now and will push the changes shortly!
Sounds great! I am not a Diffusers maintainer, but I believe such a calibration step will indeed make it universal (this is similar to compiling). After this update, I believe it will be completely usable!
Thanks for the thoughtful discussions here @AlanPonnachan and @leffff! I will leave my two cents below:
- The calibration steps outlined in https://github.com/huggingface/diffusers/pull/12744#issuecomment-3609618101 are great! What if we ship a utility to just log/print those values so that users can pass them to the MagCacheConfig? We could provide that utility script either from scripts or from src/diffusers/utils. I think this will be more explicit and enforce some kind of user awareness.
- If the mag_ratios are supposed to be checkpoint-dependent, I think we should always enforce passing mag_ratios from the config and, when they are not provided, raise a sensible error message that instructs the user on how to derive mag_ratios.
Ccing @DN6 to get his thoughts here, too.
Thanks @sayakpaul and @leffff for the feedback!
I have updated the PR to address these points. Instead of a standalone utility script, I integrated the calibration logic directly into the hook configuration for better usability:
- Strict Enforcement: mag_ratios is now mandatory. If it is not provided (and calibrate=False), a ValueError is raised with instructions on how to derive it.
- Calibration Mode: I added a calibrate=True flag to MagCacheConfig. When enabled, the hooks run full inference (no skipping) and log/print the calculated magnitude ratios at the end. This allows users to easily generate ratios for any model/scheduler combination using their existing pipeline code.
- Flux Convenience: I kept FLUX_MAG_RATIOS as a constant for convenience, but the user must now explicitly import and pass it to the config.
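The strict-enforcement rule can be sketched as a dataclass whose __post_init__ rejects a config that has neither mag_ratios nor calibrate=True. MagCacheConfigSketch is a hypothetical stand-in, not the real MagCacheConfig; its field names follow the ones mentioned in this thread, and its defaults are the values suggested later in the discussion (threshold=0.06, max_skip_steps=3, retention_ratio=0.2, num_inference_steps=28).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MagCacheConfigSketch:
    mag_ratios: Optional[List[float]] = None
    calibrate: bool = False
    threshold: float = 0.06
    max_skip_steps: int = 3
    retention_ratio: float = 0.2
    num_inference_steps: int = 28

    def __post_init__(self) -> None:
        # Ratios are checkpoint/scheduler dependent, so there is no silent fallback.
        if self.mag_ratios is None and not self.calibrate:
            raise ValueError(
                "`mag_ratios` must be provided. Run the pipeline once with "
                "`calibrate=True` to log the ratios for your model and scheduler, "
                "then pass them as `mag_ratios=[...]`."
            )


try:
    MagCacheConfigSketch()
except ValueError as err:
    print(err)
```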
Ready for review!
Looks great! Could you please provide a usage example that shows how to:
- import and load a specific model
- run inference
- calibrate
- run inference with MagCache
And please provide the generations.
To be sure it works, please provide generations for SD3.5 Medium, Flux, and Wan 2.1 T2V 1.3B. Also, since caching is suitable for all tasks, could we also try Kandinsky 5.0 Video Pro I2V (kandinskylab/Kandinsky-5.0-I2V-Pro-sft-5s-Diffusers)?
@leffff
1. Usage Example
import torch
from diffusers import FluxPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
# CALIBRATION STEP
config = MagCacheConfig(calibrate=True, num_inference_steps=4)
apply_mag_cache(pipe.transformer, config)
pipe("A cat playing chess", num_inference_steps=4)
# Logs: [1.0, 1.37, 0.97, 0.87]
# INFERENCE STEP
config = MagCacheConfig(mag_ratios=[1.0, 1.37, 0.97, 0.87], num_inference_steps=4)
apply_mag_cache(pipe.transformer, config)
pipe("A cat playing chess", num_inference_steps=4)
2. Benchmark Results
I validated the implementation on Flux, SD 3.5, and Wan 2.1 using a T4 Colab environment.
| Model | Baseline Time | MagCache Time | Speedup | Generated Ratios | Notes |
|---|---|---|---|---|---|
| Flux.1-Schnell | ~10m 31s | ~7m 55s | ~1.33x | [1.0, 1.371991753578186, 0.9733748435974121, 0.8640348315238953] | Full generation successful. |
| SD 3.5 Medium (threshold = 0.15) | ~4m 46s | ~1m 36s | ~3.0x | [1.0, 1.0182535648345947, 1.0475366115570068, 1.0192866325378418, 1.007051706314087, 1.013611078262329, 1.0057004690170288, 1.0053653717041016, 0.9967299699783325, 0.9996473789215088, 0.9947380423545837, 0.9942205548286438, 0.9788764715194702, 0.9873758554458618, 0.9801908731460571, 0.9658506512641907, 0.9565740823745728, 0.9469784498214722, 0.9258849620819092, 1.3470091819763184] | Validated hooks without T5 encoder (RAM limit). |
| SD 3.5 Medium (threshold = 0.03) | ~4m 51s | ~2m 43s | ~1.79x | [1.0, 1.0172510147094727, 1.0381698608398438, 1.0167241096496582, 1.0070651769638062, 1.0107033252716064, 1.0043275356292725, 1.0044840574264526, 0.9945924282073975, 0.9993497133255005, 0.9941253662109375, 0.9904510974884033, 0.9783601760864258, 0.9845271110534668, 0.9771078824996948, 0.9657461047172546, 0.9529474973678589, 0.9403719305992126, 0.9110836982727051, 1.3032703399658203] | Validated hooks without T5 encoder (RAM limit). |
| Wan 2.1 (1.3B) | ~22s | ~1s | ~22x | [1.0, 0.9901599884033203, 0.9980327486991882, 1.001886248588562, 1.0045758485794067, 1.0067006349563599, 1.0093395709991455, 1.0129660367965698, 1.0191177129745483, 1.0308380126953125] | Validated hooks with dummy embeddings (RAM limit). |
| Kandinsky 5.0 | N/A | N/A | N/A | N/A | Added visual_transformer_blocks support, but hit disk limits. Logic matches Wan/Flux. |
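The Speedup column is baseline time divided by MagCache time. The small helpers below are hypothetical, written only to check that arithmetic; they reproduce the reported values from the times in the table.

```python
import re


def to_seconds(duration: str) -> int:
    """Parse strings like '~10m 31s' or '~22s' into whole seconds."""
    minutes = re.search(r"(\d+)m", duration)
    seconds = re.search(r"(\d+)s", duration)
    return (int(minutes.group(1)) * 60 if minutes else 0) + (int(seconds.group(1)) if seconds else 0)


def speedup(baseline: str, magcache: str) -> float:
    return to_seconds(baseline) / to_seconds(magcache)


print(round(speedup("~10m 31s", "~7m 55s"), 2))  # Flux: 631 s / 475 s -> 1.33
print(round(speedup("~4m 51s", "~2m 43s"), 2))  # SD 3.5, threshold 0.03: 291 s / 163 s -> 1.79
```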
3. Generations
Attached below are the outputs for the successful runs.
Flux (Baseline):
Flux (MagCache):
SD 3.5 (Baseline):
SD 3.5 (MagCache -- threshold = 0.15):
SD 3.5 (Baseline):
SD 3.5 (MagCache -- threshold = 0.03):
Here is the Colab notebook used to generate the benchmarks above. It includes the full setup, memory optimizations (sequential offloading/dummy embeds), and the execution logs:
@bot /style
Style bot fixed some files and pushed the changes.
This looks good! Thank you for conducting measurements and testing different models.
@AlanPonnachan thanks for your great work thus far! Some minor questions (mostly out of curiosity) below:
- Do we need to pass num_inference_steps along with mag_ratios? Just passing mag_ratios should be enough, no?
- Can we also check whether the enable_cache() method on the transformer works as expected?
- Would mag_ratios computed for a higher number of num_inference_steps work for a lower number of num_inference_steps?
- I wonder why there is a RAM limit when using SD3.5 Medium and not Flux. I ask because Flux's DiT is larger and the largest text encoder in both cases is T5.
- Regarding "Added visual_transformer_blocks support, but hit disk limits. Logic matches Wan/Flux.": is it the Colab disk limit?
Additionally, I could obtain outputs with Wan 1.3B and they look reasonable to me.
Code
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
from diffusers.utils import export_to_video
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
num_inference_steps = 50
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# config = MagCacheConfig(calibrate=True, num_inference_steps=num_inference_steps)
# apply_mag_cache(pipe.transformer, config)
config = MagCacheConfig(
mag_ratios=[1.0, 1.0337707996368408, 0.9908783435821533, 0.9898878931999207, 0.990186870098114, 0.989551305770874, 0.9898356199264526, 0.9901290535926819, 0.9913457632064819, 0.9893063902854919, 0.990695059299469, 0.9892956614494324, 0.9910416603088379, 0.9908630847930908, 0.9897039532661438, 0.9907404184341431, 0.98955237865448, 0.9905906915664673, 0.9881031513214111, 0.98977130651474, 0.9878108501434326, 0.9873648285865784, 0.98862624168396, 0.9870336055755615, 0.9855726957321167, 0.9857151508331299, 0.98496013879776, 0.9846605658531189, 0.9835416674613953, 0.984062671661377, 0.9805435538291931, 0.9828993678092957, 0.9804039001464844, 0.9776313304901123, 0.9769471883773804, 0.9752448201179504, 0.973810076713562, 0.9708614349365234, 0.9703076481819153, 0.9666262865066528, 0.9658275246620178, 0.9612534046173096, 0.9553734064102173, 0.9522399306297302, 0.9467942118644714, 0.9430344104766846, 0.9335862994194031, 0.9285727739334106, 0.9244886636734009, 0.9560992121696472],
num_inference_steps=num_inference_steps
)
apply_mag_cache(pipe.transformer, config)
prompt = "A cat walks on the grass, realistic"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
num_frames=81,
guidance_scale=5.0,
num_inference_steps=num_inference_steps,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
Outputs:
# Calibration
100%|██████████████████████████████| 50/50 [01:35<00:00, 1.91s/it]
# After using the `mag_ratios`
100%|██████████████████████████████| 50/50 [00:27<00:00, 1.82it/s]
Video output:
https://github.com/user-attachments/assets/6b63a00d-bcf3-41ff-a2b8-5fac26d17bcf
However, there seems to be a problem when using Kandinsky 5 and the error seems obvious to me.
Error: https://pastebin.com/F7arxTWg
Code
import torch
from diffusers import Kandinsky5T2VPipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
from diffusers.utils import export_to_video
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
num_inference_steps = 50
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")
config = MagCacheConfig(calibrate=True, num_inference_steps=num_inference_steps)
apply_mag_cache(pipe.transformer, config)
# config = MagCacheConfig(
# mag_ratios=[...],
# num_inference_steps=num_inference_steps
# )
# apply_mag_cache(pipe.transformer, config)
prompt = "A cat and a dog baking a cake together in a kitchen."
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
output = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=512,
width=768,
num_frames=121, # ~5 seconds at 24fps
num_inference_steps=num_inference_steps,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "output_kandinsky.mp4", fps=24, quality=9)
For this, instead of hard-coding the argument name as in the following line https://github.com/AlanPonnachan/diffusers/blob/ebbebbefac42984b1d09e63a17b6d7f11af66073/src/diffusers/hooks/mag_cache.py#L186, maybe we could pass it through the cache config? I understand this could be difficult for users, but since they have to perform calibration anyway, I think this is still reasonable.
Just for curiosity, I changed to:
diff --git a/src/diffusers/hooks/mag_cache.py b/src/diffusers/hooks/mag_cache.py
index 71ebfcb25..0a7c333db 100644
--- a/src/diffusers/hooks/mag_cache.py
+++ b/src/diffusers/hooks/mag_cache.py
@@ -183,7 +183,7 @@ class MagCacheHeadHook(ModelHook):
self.state_manager.set_context("inference")
# Capture input hidden_states
- hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ hidden_states = self._metadata._get_parameter_from_args_kwargs("visual_embed", args, kwargs)
state: MagCacheState = self.state_manager.get_state()
state.head_block_input = hidden_states
@@ -297,7 +297,7 @@ class MagCacheBlockHook(ModelHook):
state: MagCacheState = self.state_manager.get_state()
if not state.should_compute:
- hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ hidden_states = self._metadata._get_parameter_from_args_kwargs("visual_embed", args, kwargs)
if self.is_tail:
# Still need to advance step index even if we skip
self._advance_step(state)
I then ran the above code, but now I am getting a pair of mag_ratios:
[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):
[1.0, 1.0096147060394287, 0.8601706027984619, 1.0066865682601929, 1.1018145084381104, 1.0066889524459839, 1.07235848903656, 1.006271243095398, 1.0583757162094116, 1.0066468715667725, 1.0803261995315552, 1.0059221982955933, 1.0304542779922485, 1.0061317682266235, 1.0251237154006958, 1.006355881690979, 1.0230522155761719, 1.0063568353652954, 1.0354706048965454, 1.006076455116272, 1.0154225826263428, 1.0064369440078735, 1.0257697105407715, 1.0066747665405273, 1.012341856956482, 1.0068379640579224, 1.017471432685852, 1.0070058107376099, 1.008599877357483, 1.00702702999115, 1.0158008337020874, 1.0070949792861938, 1.0113613605499268, 1.0063375234603882, 1.0122487545013428, 1.0064034461975098, 1.0091496706008911, 1.0062494277954102, 1.0109937191009521, 1.0061204433441162, 1.0084550380706787, 1.0059889554977417, 1.006821870803833, 1.0058847665786743, 1.0106556415557861, 1.005847454071045, 1.0057544708251953, 1.0058276653289795, 1.0092748403549194, 1.005746841430664]
[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):
[1.0, 1.0056898593902588, 1.0074970722198486, 1.005563735961914, 1.0061627626419067, 1.0054070949554443, 1.0053973197937012, 1.0052893161773682, 1.0067739486694336, 1.0051906108856201, 1.0049010515213013, 1.0050380229949951, 1.0056493282318115, 1.0049028396606445, 1.0056771039962769, 1.0048167705535889, 1.0038255453109741, 1.0047082901000977, 1.0041747093200684, 1.004562258720398, 1.002451777458191, 1.0044060945510864, 1.0022073984146118, 1.0042728185653687, 1.0011045932769775, 1.0041989088058472, 0.9996317625045776, 1.0040632486343384, 0.9980409741401672, 1.0038821697235107, 0.9960299134254456, 1.004146933555603, 0.9924721717834473, 1.0041824579238892, 0.9876144528388977, 1.0041331052780151, 0.9839898943901062, 1.003833293914795, 0.976319432258606, 1.0032036304473877, 0.9627748131752014, 1.002505898475647, 0.9450504779815674, 1.001646637916565, 0.9085856080055237, 0.9999536275863647, 0.8368133306503296, 0.9975034594535828, 0.6354470252990723, 0.9997955560684204]
When applying the first one, I got:
https://github.com/user-attachments/assets/47c1572a-ea31-4909-8bea-56c23565b7bd
When applying the second one, I got:
https://github.com/user-attachments/assets/8eb466b9-7fcf-4145-b0e7-e6b7ed4c7e29
Thought this would help :)
@sayakpaul thank you for running inferences on your side; it helped a lot.
1. Regarding num_inference_steps:
It is required because the hook needs to know the "target length" to:
- Reset its state (accumulated_error, step_index) at the exact end of the generation loop.
- Interpolate the provided mag_ratios (which might be calibrated for 50 steps) down to the current run (e.g., 20 steps).
2. Regarding enable_cache():
I have updated src/diffusers/models/cache_utils.py to support MagCacheConfig. You can now use:
pipe.transformer.enable_cache(MagCacheConfig(...)). (tested)
3. Do mag_ratios computed for a higher num_inference_steps work for a lower num_inference_steps?
Yes. The mag_ratios represent the model's intrinsic "magnitude decay curve" (how the residual strength changes from noise to image). This curve's shape is consistent regardless of the step count.
In fact, calibrating on a higher number of steps (e.g., 50) is often better because it creates a high-resolution map of this curve.
Our implementation handles this automatically in MagCacheConfig.__post_init__ using nearest_interp. If you provide 50 ratios but run 10 steps, the code correctly downsamples the curve to pick the 10 representative values that match the current schedule.
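A minimal sketch of that downsampling step. nearest_interp below is modeled on the helper of the same name in the original MagCache reference code (which uses NumPy); it is rewritten in plain Python here, and whether the PR's version matches it exactly is an assumption. Each target step is mapped to the nearest source step on an evenly spaced grid, so the first and last ratios are always kept.

```python
from typing import List


def nearest_interp(src: List[float], target_length: int) -> List[float]:
    if target_length == 1:
        return [src[-1]]
    scale = (len(src) - 1) / (target_length - 1)
    # Python's round() is round-half-to-even, like np.round
    return [src[round(i * scale)] for i in range(target_length)]


# Which of 50 calibrated steps are kept for a 10-step run
print(nearest_interp(list(range(50)), 10))  # [0, 5, 11, 16, 22, 27, 33, 38, 44, 49]
```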
4. Regarding the RAM limit:
When I checked, although Flux has a larger DiT, the bottleneck here is system RAM during loading. Flux uses two text encoders (CLIP-L and T5-XXL), whereas SD3.5 loads three (CLIP-L, T5-XXL, and OpenCLIP-BigG). I suspect that is the reason for the Colab crash.
5. Regarding the disk limit: yes, it was the Colab disk limit.
6. Regarding Kandinsky 5
You are right: the ValueError is because Kandinsky uses visual_embed instead of hidden_states.
I have fixed this by:
- Adding hidden_states_argument_name to TransformerBlockMetadata (defaulting to "hidden_states" so no regressions occur for other models).
- Registering Kandinsky5TransformerDecoderBlock with hidden_states_argument_name="visual_embed".
- Updating the hooks to read this name from the metadata dynamically.
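The idea behind hidden_states_argument_name can be illustrated without diffusers: resolve the configured argument name against the block's forward signature, whether the caller passed it positionally or as a keyword. BlockMetadataSketch and get_hidden_states are hypothetical names; the real TransformerBlockMetadata and its _get_parameter_from_args_kwargs may work differently. The two toy forward functions only mimic the argument names mentioned in this thread.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass
class BlockMetadataSketch:
    hidden_states_argument_name: str = "hidden_states"  # default keeps existing models working

    def get_hidden_states(self, forward: Callable, args: Tuple, kwargs: Dict[str, Any]) -> Any:
        # Map positional and keyword arguments onto the forward signature's parameter names.
        bound = inspect.signature(forward).bind_partial(*args, **kwargs)
        name = self.hidden_states_argument_name
        if name not in bound.arguments:
            raise ValueError(f"Argument '{name}' was not passed to {forward.__qualname__}")
        return bound.arguments[name]


def flux_like_block(hidden_states, encoder_hidden_states, temb):
    ...


def kandinsky_like_block(visual_embed, text_embed, time_embed):
    ...


default_meta = BlockMetadataSketch()
kandinsky_meta = BlockMetadataSketch(hidden_states_argument_name="visual_embed")
print(default_meta.get_hidden_states(flux_like_block, ("x", "ctx", "t"), {}))  # x
print(kandinsky_meta.get_hidden_states(kandinsky_like_block, (), {"visual_embed": "v"}))  # v
```

With the default name, the Kandinsky-style block raises a ValueError, which mirrors the failure reported above.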
7. Regarding the two arrays of mag_ratios in Kandinsky 5:
I checked the Kandinsky pipeline, and I think this happens because it runs sequential CFG (calling the transformer twice per step). The first array corresponds to the conditional pass, and the second to the unconditional pass. Users can typically pick the first array (conditional) for calibration.
When i checked , while Flux has a larger DiT, the bottleneck here is System RAM during loading. Flux uses two text encoders (CLIP-L and T5-XXL), whereas SD3.5 loads three (CLIP-L, T5-XXL, and OpenCLIP-BigG). I guess that may be reason for colab crash.
Makes sense, yeah!
Interpolate the provided mag_ratios (which might be calibrated for 50 steps) down to the current run (e.g., 20 steps).
Yes. The mag_ratios represent the model's intrinsic "magnitude decay curve" (how the residual strength changes from noise to image). This curve's shape is consistent regardless of the step count.
This is awesome. Let's make sure we document it once we're at that point.
I checked Kandinsky pipeline, I think this happens because the Kandinsky pipeline runs sequential CFG (calling the transformer twice per step). The first array corresponds to the Conditional pass, and the second to the Unconditional pass. Users can typically pick the first array (Conditional) for calibration.
Okay then this needs to be documented as well. However, there are some small models where we run CFG in a batched manner. Would that affect mag_ratios?
Cc: @Zehong-Ma!
Hey maybe you would like to review the PR as well :)
Thanks for your review, and thanks to @AlanPonnachan for the contribution. I have briefly reviewed the pull request. Most of the discussion is correct and concise. There are two important things that should be clearly discussed or fixed.
- Support CFG for different diffusion models. First, we should check whether the model uses CFG. Second, if CFG is used, are the conditional and unconditional passes processed sequentially or in a batch? In my opinion, a good solution is to maintain two distinct MagCache states for the unconditional and conditional passes.
- The following default MagCache config values may work better for all diffusion models. The previous default values may sacrifice quality to achieve an extremely fast speed.
threshold: float = 0.06
max_skip_steps: int = 3
retention_ratio: float = 0.2
num_inference_steps: int = 28
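One possible shape for the "two distinct states" suggestion, under a strong assumption: with sequential CFG the transformer is called exactly twice per denoising step, always in the same order, so call parity identifies the branch; with batched CFG (one call per step) a single state is used. BranchState and CallParityStateManager are hypothetical. As discussed in this thread, the hooks in this PR receive no branch signal, and this parity trick would break if a pipeline skipped or reordered a pass.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class BranchState:
    accumulated_ratio: float = 1.0
    accumulated_steps: int = 0
    accumulated_err: float = 0.0


class CallParityStateManager:
    def __init__(self, calls_per_step: int = 2):
        # 2 for sequential CFG (cond, then uncond); 1 for batched CFG or no CFG
        self.states: List[BranchState] = [BranchState() for _ in range(calls_per_step)]
        self.num_calls = 0

    def state_for_next_call(self) -> BranchState:
        # Pick the state by position within the step, then count the call.
        state = self.states[self.num_calls % len(self.states)]
        self.num_calls += 1
        return state


manager = CallParityStateManager(calls_per_step=2)
cond_0 = manager.state_for_next_call()
uncond_0 = manager.state_for_next_call()
cond_1 = manager.state_for_next_call()
print(cond_0 is cond_1, cond_0 is uncond_0)  # True False
```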
- Support CFG for different diffusion models. First, we should check whether the model utilizes the CFG. Secondly, if the CFG is adopted, does the Conditional pass and Unconditional pass are processed in a sequential manner or batched manner? In my opinion, a good solution is to maintain two distinct magcache states for unconditional and conditional pass, separately.
From my observation of the codebase, I found that maintaining distinct states for conditional/unconditional passes is quite difficult with the current architecture. In sequential CFG pipelines (like Kandinsky), the transformer is called twice per step. However, the hooks attached to the transformer blocks are "blind" to the pipeline's loop: they don't receive any signal indicating whether the current forward pass is for the conditional or unconditional branch. To support distinct states, we would likely need to modify the pipelines to explicitly pass this context down to the hooks. Since that would be a much larger change affecting core logic, I focused on documenting this behavior instead. I have added a TIP in the documentation explaining that sequential CFG models may produce two calibration arrays, and users should typically use the first one (Conditional).
- The default MagCache config with the following values may be better for all diffusion models. The previous default values may sacrifice quality to achieve a extremely fast speed.
threshold: float = 0.06
max_skip_steps: int = 3
retention_ratio: float = 0.2
num_inference_steps: int = 28
updated with this config
@sayakpaul, I added documentation and tests. Please check and let me know if any changes are required.
To support distinct states, we would likely need to modify the pipelines to explicitly pass this context down to the hooks. Since that would be a much larger change affecting core logic, I focused on documenting this behavior instead. I have added a TIP in the documentation explaining that sequential CFG models may produce two calibration arrays, and users should typically use the first one (Conditional).
Thanks for this (and I agree with this approach)! Should we also document what the users would need to do / how to proceed when the CFG is implemented in the batched manner (SDXL, for example)?
Additionally, @Zehong-Ma WDYT?
@Meatfucker if you want to test the PR.
@bot /style
Style bot fixed some files and pushed the changes.
Also, if you have a chance to test it with torch.compile and report any performance gains from that, it'd be golden. Not a priority, though.
I ran the torch.compile benchmark locally on an RTX 5050.
I used the below script:
import torch
import time
import gc
from diffusers import StableDiffusion3Pipeline
from diffusers.hooks import MagCacheConfig, apply_mag_cache
torch.set_float32_matmul_precision('high')
def flush():
    gc.collect()
    torch.cuda.empty_cache()
print(f" Benchmarking on {torch.cuda.get_device_name(0)}...")
# Load SD3.5 (No T5 to fit in 8GB VRAM)
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3.5-medium",
text_encoder_3=None,
tokenizer_3=None,
torch_dtype=torch.bfloat16
).to("cuda")
# Setup MagCache
steps = 20
config = MagCacheConfig(
mag_ratios=[1.0] * steps,
num_inference_steps=steps,
threshold=0.05
)
apply_mag_cache(pipe.transformer, config)
# Resolution: 512x512 is safe for 8GB. 1024x1024 will likely OOM.
kwargs = {"height": 512, "width": 512, "num_inference_steps": steps}
prompt = "A photo of a fast car"
# --- RUN 1: EAGER MODE ---
print("\n>> [1/2] Benchmarking MagCache (Eager)...")
# Warmup
pipe(prompt, **kwargs)
torch.cuda.synchronize()
start = time.time()
pipe(prompt, **kwargs)
torch.cuda.synchronize()
eager_time = time.time() - start
print(f" Eager Time: {eager_time:.4f}s")
# --- RUN 2: COMPILED MODE ---
print("\n>> [2/2] Benchmarking MagCache (torch.compile)...")
print(" Compiling transformer ...")
# 'max-autotune' gives best speed but uses more memory/time.
# 'reduce-overhead' is safer for 8GB VRAM.
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=False)
try:
    # Warmup (triggers compilation)
    start_compile = time.time()
    pipe(prompt, **kwargs)
    torch.cuda.synchronize()  # make sure compilation + warmup GPU work is finished
    print(f" Compilation + Warmup took: {time.time() - start_compile:.2f}s")
    # Benchmark
    torch.cuda.synchronize()
    start = time.time()
    pipe(prompt, **kwargs)
    torch.cuda.synchronize()
    compile_time = time.time() - start
    print(f" Compiled Time: {compile_time:.4f}s")
    print(f"\nSpeedup: {eager_time / compile_time:.2f}x")
except Exception as e:
    print(f"\nCompilation failed: {e}")
Results (SD 3.5 Medium, 512px):
- Compatibility: Success. The model compiled and ran without errors.
- Performance: The results fluctuate between ~0.9x and ~1.2x across multiple runs.
The successful execution confirms that @torch.compiler.disable correctly isolates the dynamic hook logic from the graph. I think the fluctuation suggests that, for this small workload (512px), the overhead of context-switching between the Python hook and the compiled graph roughly offsets the speed gains from kernel fusion. On larger, compute-bound workloads (e.g., high-res video on A100s), the compilation benefits would likely outweigh this constant overhead, resulting in a consistent speedup.
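Since the measured ratio swings between ~0.9x and ~1.2x, one way to make the comparison less noisy is to time several runs per mode and compare medians instead of a single run. time_runs and median_speedup are hypothetical helpers using only the standard library; in the script above, the timed function would wrap pipe(prompt, **kwargs) followed by torch.cuda.synchronize() so GPU work is included. Warmup runs are discarded.

```python
import statistics
import time
from typing import Callable, List


def time_runs(fn: Callable[[], object], repeats: int = 5, warmup: int = 1) -> List[float]:
    """Call fn `warmup` times untimed, then `repeats` times, returning seconds per run."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return timings


def median_speedup(baseline: List[float], candidate: List[float]) -> float:
    return statistics.median(baseline) / statistics.median(candidate)


# Toy stand-ins for the eager and compiled pipeline calls
slow = time_runs(lambda: sum(range(200_000)))
fast = time_runs(lambda: sum(range(20_000)))
print(f"median speedup: {median_speedup(slow, fast):.2f}x")
```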
This is looking good to me, but I am debating whether we should have a single unified hook class instead of maintaining two. This would reduce the complexity a bit and likely make it simpler.
Here's the signature I had in mind: MagCacheHook(state_manager, config, role=role) where role would be "head" or "tail".
Or do you think it's easier to maintain two classes?
Regarding Single vs Two Classes: I considered merging them, but I ultimately kept them separate to distinguish between the 'Controller' role (HeadHook determining whether to skip) and the 'Worker' role (BlockHook executing the skip).
Merging them would likely require injecting conditional logic (e.g., if self.is_head: ...) inside the forward pass. Keeping them distinct also aligns with the existing design pattern in FirstBlockCache.
That said, I don't hold a strong opinion here; if you prefer a unified class to keep the file smaller, I am happy to refactor it!
@bot /style
Style bot fixed some files and pushed the changes.