
Implement TeaCache

Open · LawJarp-A opened this pull request 1 month ago · 14 comments

What does this PR do?

What is TeaCache?

TeaCache (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.

Architecture

┌─────────────────────────────────────────────────────────────┐
│                    TeaCache Hook Flow                       │
├─────────────────────────────────────────────────────────────┤
│  1. Extract modulated input from first transformer block    │
│  2. Compute relative L1 distance vs previous timestep       │
│  3. Apply model-specific polynomial rescaling               │
│  4. Accumulate distance, compare to threshold               │
│                                                             │
│  If accumulated < threshold → Reuse cached residual (FAST)  │
│  If accumulated ≥ threshold → Full transformer pass (SLOW)  │
└─────────────────────────────────────────────────────────────┘

Integrates with existing HookRegistry and CacheMixin patterns in diffusers.
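
For intuition, here is a minimal sketch of the per-step decision the hook makes (illustrative only; the state fields, helper name, and coefficient handling are placeholders rather than the PR's actual API):

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class TeaCacheState:
    previous_modulated_input: Optional[torch.Tensor] = None
    accumulated_distance: float = 0.0


def should_reuse_cache(
    state: TeaCacheState,
    modulated_input: torch.Tensor,
    coefficients: List[float],   # model-specific polynomial, highest degree first
    rel_l1_thresh: float,
) -> bool:
    if state.previous_modulated_input is None:
        # First timestep: nothing to compare against, run the full transformer.
        reuse = False
    else:
        # Relative L1 distance between the current and previous modulated inputs.
        rel_l1 = (
            (modulated_input - state.previous_modulated_input).abs().mean()
            / state.previous_modulated_input.abs().mean()
        ).item()
        # Model-specific polynomial rescaling of the raw distance.
        rescaled = sum(c * rel_l1 ** i for i, c in enumerate(reversed(coefficients)))
        state.accumulated_distance += rescaled
        if state.accumulated_distance < rel_l1_thresh:
            reuse = True                       # reuse cached residual (FAST path)
        else:
            reuse = False                      # full transformer pass (SLOW path)
            state.accumulated_distance = 0.0   # reset after recomputing
    state.previous_modulated_input = modulated_input
    return reuse

When reuse is True, the hook adds the cached residual to the hidden states instead of running the transformer blocks; after a full pass it stores the new residual for the next step.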

Supported Models

| Model | Coefficients Status |
|---|---|
| FLUX | Tested |
| FLUX-Kontext | Ready |
| Mochi | Ready |
| Lumina2 | Ready |
| CogVideoX (2b/5b/1.5) | Ready |

Benchmark Results (FLUX.1-schnell, 20 steps, 512x512)

| Threshold | Time | Speedup |
|---|---|---|
| Baseline | 3.76s | 1.00x |
| 0.25 | 2.28s | 1.65x |
| 0.4 | 1.86s | 2.02x |
| 0.8 | 1.30s | 2.89x |

Benchmark Results (Lumina2, 28 steps, 512x512)

| Threshold | Time | Speedup |
|---|---|---|
| Baseline | 4.33s | 1.00x |
| 0.25 | 2.22s | 1.95x |
| 0.4 | 1.79s | 2.42x |
| 0.8 | 1.43s | 3.02x |

Benchmark Results (CogVideoX-2b, 50 steps, 720x720, 49 frames)

| Threshold | Time | Speedup |
|---|---|---|
| Baseline | 91.96s | 1.00x |
| 0.25 | 37.98s | 2.42x |
| 0.4 | 30.39s | 3.03x |
| 0.8 | 24.30s | 3.78x |

Test Hardware: NVIDIA A100-SXM4-40GB
Framework: Diffusers with TeaCache hooks
All tests: Same seed (42) for reproducibility

Pros & Cons

Pros:

  • Training-free, drop-in speedup
  • Works with existing pipelines via enable_teacache()
  • Configurable quality/speed tradeoff
  • Proper state management between inference runs

Cons:

  • Quality degrades at high thresholds (>0.6)
  • Model-specific coefficients required

Usage

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable TeaCache (~2x speedup at the 0.4 threshold, see benchmarks above)
pipe.transformer.enable_teacache(rel_l1_thresh=0.4, num_inference_steps=20)

image = pipe("A dragon on a crystal mountain", num_inference_steps=20).images[0]

pipe.transformer.disable_cache()

Files Changed

  • src/diffusers/hooks/teacache.py - Core implementation
  • src/diffusers/models/cache_utils.py - CacheMixin integration
  • tests/hooks/test_teacache.py - Unit tests

Fixes https://github.com/huggingface/diffusers/issues/12589 and https://github.com/huggingface/diffusers/issues/12635

Who can review?

@sayakpaul @yiyixuxu

LawJarp-A · Nov 13 '25 14:11

Work done

  • [X] Implement TeaCache for the FLUX architecture using hooks (only FLUX for now)
  • [X] Add logging
  • [X] Add compatible tests

Waiting for feedback and review :) cc: @dhruvrnaik @sayakpaul @yiyixuxu

LawJarp-A · Nov 13 '25 17:11

Hi @sayakpaul @dhruvrnaik any updates?

LawJarp-A · Nov 17 '25 11:11

@LawJarp-A sorry about the delay on our end. @DN6 will review it soon.

sayakpaul · Nov 23 '25 11:11

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

DN6 · Nov 24 '25 02:11

> Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

Yep @DN6, I agree. I wanted to first implement it for a single model and get feedback before working on the full model-agnostic implementation. I'm working on that now but haven't pushed it yet. I'll take a look at First Block Cache for reference as well. In the meantime, let me know if there is anything to add to the current implementation.

LawJarp-A · Nov 24 '25 03:11

@DN6 updated it in a more model agnostic way. Requesting review and feedback

LawJarp-A · Nov 26 '25 08:11

Added multi-model support; still testing it thoroughly.

LawJarp-A · Dec 02 '25 09:12

Hi @DN6 @sayakpaul, two questions. I'm almost done testing; I'll update the PR with more descriptive results and changes, and do the final cleanup/merging.

  1. Are there any tests I should write, and anything I can refer to for them?
  2. I've added support for other models; should I add image comparisons with speedup and threshold to the PR as well?

In the meantime, any feedback would be appreciated.

LawJarp-A · Dec 08 '25 11:12

Thanks @LawJarp-A!

> Are there any tests I should write, and anything I can refer to for them?

You can refer to https://github.com/huggingface/diffusers/issues/12569 for testing

> I've added support for other models; should I add image comparisons with speedup and threshold to the PR as well?

Yes, I think that is informative for users.

sayakpaul · Dec 08 '25 12:12

> Some initial feedback. The most important question: it seems like we need to craft different logic for each model? Can we not keep it model agnostic?

I am trying to think of ways we can avoid having a separate forward for each model. Initially a generic forward seemed like the best option: it was fine when I wrote it for FLUX, but Lumina needed multi-stage preprocessing, so keeping a generic forward might not work very well :/ FasterCache and First Block Cache both work at the block level, but TeaCache works at the model level. Definitely open to ideas :)

LawJarp-A · Dec 08 '25 13:12

@sayakpaul Added a FLUX image example in the PR description. Tested it with Lumina and CogVideoX as well. Could not test with Mochi because of GPU constraints; I can maybe try with CPU offloading.

LawJarp-A · Dec 10 '25 11:12

@sayakpaul @DN6 I got the core logic working and tested it on the models my GPU can handle. Right now I have gone for a simple monolithic approach: each model's forward handlers and extractors all live in one file. I tried to abstract it as much as possible, but since TeaCache works at the model level rather than the block level (like most of the current caches, e.g. TaylorSeer, FirstBlockCache), it has proven a bit difficult to make it model agnostic.

The current implementation puts all model handlers in a single teacache.py file. This works but has scaling concerns, so I was thinking that, since we have to add model-specific functions anyway, we could make them a bit more modular design-wise.

Potential refactor: Registry + Handler pattern

diffusers/hooks/
├── teacache/
│   ├── __init__.py           # Public API
│   ├── config.py             # TeaCacheConfig
│   ├── hook.py               # TeaCacheHook (core logic)
│   ├── registry.py           # Handler registry
│   └── handlers/
│       ├── __init__.py       # Auto-imports all handlers
│       ├── base.py           # BaseTeaCacheHandler ABC
│       ├── flux.py
│       ├── mochi.py
│       ├── lumina2.py
│       └── cogvideox.py

Each handler self-registers and encapsulates its logic:

# handlers/flux.py
from .base import BaseTeaCacheHandler
from ..registry import register_handler

@register_handler("Flux", "FluxKontext")
class FluxHandler(BaseTeaCacheHandler):
    coefficients = [4.98651651e02, -2.83781631e02, ...]
    
    def extract_modulated_input(self, module, hidden_states, temb):
        return module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
    
    def handle_forward(self, module, *args, **kwargs):
        # FLUX-specific forward with ControlNet, LORA, etc.
        ...
# registry.py
_HANDLER_REGISTRY = {}

def register_handler(*model_names):
    def decorator(cls):
        for name in model_names:
            _HANDLER_REGISTRY[name] = cls
        return cls
    return decorator

def get_handler(module) -> BaseTeaCacheHandler:
    for name, handler_cls in _HANDLER_REGISTRY.items():
        if name in module.__class__.__name__:
            return handler_cls()
    raise ValueError(f"No TeaCache handler for {module.__class__.__name__}")

This is similar to how attention processors and schedulers are organized. Happy to refactor if you think it's worth it, or we can keep it simple as it is now. This has proven a bit more of a challenge to integrate than I thought xD, so I'd be happy to hear if you have any ideas.

LawJarp-A · Dec 11 '25 09:12

Hey @DN6 @sayakpaul, any updates? :)

LawJarp-A · Dec 15 '25 04:12