Skip to content

Implement TeaCache #12652

Open
LawJarp-A wants to merge 38 commits intohuggingface:mainfrom
LawJarp-A:teacache-flux
Open

Implement TeaCache #12652
LawJarp-A wants to merge 38 commits intohuggingface:mainfrom
LawJarp-A:teacache-flux

Conversation

@LawJarp-A
Copy link

@LawJarp-A LawJarp-A commented Nov 13, 2025

What does this PR do?

What is TeaCache?

TeaCache (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.

Architecture & Design

TeaCache uses a ModelHook to intercept transformer forward passes without modifying model code. The algorithm:

  1. Extracts modulated input from first transformer block (after norm1 + timestep embedding)
  2. Computes relative L1 distance vs previous timestep
  3. Applies model-specific polynomial rescaling: c[0]*x^4 + c[1]*x^3 + c[2]*x^2 + c[3]*x + c[4]
  4. Accumulates rescaled distance across timesteps
  5. If accumulated < threshold → Reuses cached residual (FAST)
  6. If accumulated >= threshold → Full transformer pass (SLOW, update cache)

Key Design Features:

  • Hook-based: Integrates with HookRegistry and CacheMixin for lifecycle management
  • State Isolation: StateManager with context-aware state for CFG conditional/unconditional branches
  • Model Auto-Detection: Detects model type from class name and config path (specific variants checked first)
  • Boundary Guarantee: First and last timesteps always computed fully for quality
  • Specialized Strategies: Dual residual caching (CogVideoX), per-sequence-length caching (Lumina2)

Supported Models

Model Coefficients Status
FLUX Auto-detected Tested
FLUX-Kontext Auto-detected Ready
Mochi Auto-detected Ready
Lumina2 Auto-detected Ready
CogVideoX (2b/5b/1.5-5B) Auto-detected Ready

All models support automatic coefficient detection based on model class name and config path. Custom coefficients can also be provided via TeaCacheConfig.


Benchmark Results (FLUX.1-dev)

Threshold Time Speedup
Baseline 9.26s 1.00x
0.2 6.85s 1.35x
0.4 5.24s 1.77x
0.6 4.64s 2.00x
0.8 4.18s 2.22x

Benchmark Results (Lumina2)

Threshold Time Speedup
Baseline 3.45s 1.00x
0.2 3.07s 1.12x
0.4 2.27s 1.52x
0.6 1.84s 1.88x

Benchmark Results (CogVideoX-2b)

Threshold Time Speedup
Baseline 26.27s 1.00x
0.3 23.97s 1.10x
0.5 22.57s 1.16x
0.7 19.31s 1.36x
0.9 17.38s 1.51x

Benchmark Results (Mochi)

Threshold Time Speedup
Baseline 7.71s 1.00x
0.05 6.27s 1.23x
0.06 6.03s 1.28x
0.08 5.73s 1.35x
0.10 5.41s 1.42x

Test Hardware: NVIDIA h100
Framework: Diffusers with TeaCache hooks
All tests: Same seed (42) for reproducibility

Usage

from diffusers import FluxPipeline
from diffusers.hooks import TeaCacheConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable TeaCache (1.75x speedup with 0.4 threshold)
config = TeaCacheConfig(rel_l1_thresh=0.4)
pipe.transformer.enable_cache(config)

image = pipe("A dragon on a crystal mountain", num_inference_steps=20).images[0]

pipe.transformer.disable_cache()

Configuration Options

The TeaCacheConfig supports the following parameters:

  • rel_l1_thresh (float, default=0.2): Threshold for accumulated relative L1 distance. Recommended values: 0.25 for ~1.5x speedup, 0.4 for ~1.8x, 0.6 for ~2.0x. Mochi models require lower thresholds (0.06-0.09).
  • coefficients (List[float], optional): Polynomial coefficients for rescaling L1 distance. Auto-detected based on model type if not provided.
  • num_inference_steps (int, optional): Total inference steps. Ensures first/last timesteps are always computed. Auto-detected if not provided.
  • num_inference_steps_callback (Callable[[], int], optional): Callback returning total inference steps. Alternative to num_inference_steps.
  • current_timestep_callback (Callable[[], int], optional): Callback returning current timestep. Used for debugging/statistics.

Files Changed

  • src/diffusers/hooks/teacache.py - Core implementation with model-specific forward functions
  • src/diffusers/models/cache_utils.py - CacheMixin integration
  • src/diffusers/hooks/__init__.py - Export TeaCacheConfig and apply_teacache
  • tests/hooks/test_teacache.py - Comprehensive unit tests

Fixes # (issue)
#12589
#12635

Before submitting

Who can review?

@sayakpaul @yiyixuxu @DN6

@sayakpaul sayakpaul requested a review from DN6 November 13, 2025 16:49
@LawJarp-A
Copy link
Author

LawJarp-A commented Nov 13, 2025

Work done

  • Implement teacache for FLUX architecture using hooks (only flux for now)
  • add logging
  • add compatible tests

Waiting for feedback and review :)
cc: @dhruvrnaik @sayakpaul @yiyixuxu

@LawJarp-A LawJarp-A marked this pull request as ready for review November 14, 2025 08:23
@LawJarp-A
Copy link
Author

Hi @sayakpaul @dhruvrnaik any updates?

@sayakpaul
Copy link
Member

@LawJarp-A sorry about the delay on our end. @DN6 will review it soon.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@DN6
Copy link
Collaborator

DN6 commented Nov 24, 2025

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

@LawJarp-A
Copy link
Author

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

Yep @DN6 , I agree, I wanted to first implement it just for a single model and get feedback on that before I work on Model agnostic full implementation. I'm sort of working on it, didn't push it yet. I'll take a look at First block cache for reference as well.
On the same note, lemme know if there is anything to add to the current implementation

@LawJarp-A
Copy link
Author

LawJarp-A commented Nov 26, 2025

@DN6 updated it in a more model agnostic way.
Requesting review and feedback

@LawJarp-A
Copy link
Author

Added multi model support, testing it thoroughly though.

@LawJarp-A
Copy link
Author

Hi @DN6 @sayakpaul
Two questions, I'm almost done testing, I'll update the PR with more descriptive results and changes. And do final cleanup/merging etc

  1. Any tests I should write and anything I can refer to for the same?
  2. Added support for other models, I'll add pictures comparison with speedup and threshold to the PR as well?

In the meantime any feedback would be appreciated

@sayakpaul
Copy link
Member

Thanks @LawJarp-A!

Any tests I should write and anything I can refer to for the same?

You can refer to #12569 for testing

Added support for other models, I'll add pictures comparison with speedup and threshold to the PR as well?

Yes, I think that is informative for users.

Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some initial feedback. Most important question is it seems like we need to craft different logic based on different model? Can we not keep it model agnostic?

@LawJarp-A
Copy link
Author

LawJarp-A commented Dec 8, 2025

I am trying to think if ways we can avoid having the forward model for each model now. Initially that seemed like th ebe

Some initial feedback. Most important question is it seems like we need to craft different logic based on different model? Can we not keep it model agnostic?

t was fine when I wrote for flux, but lumina needed multi stage preprocessing.
I am trying to think how to , but keeping a generic forward might not work very well :/
Firstcache, FirstBlock all work block level, but TeaCache is more model level.
Defo open to ideas :)

@LawJarp-A
Copy link
Author

The per-model forward code is unavoidable due to different model architectures. The adapter pattern was an attempt to organize this, but I agree standalone functions would be cleaner. I'll refactor.

Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
…ctions

Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
@LawJarp-A
Copy link
Author

Hi @DN6 , I've updated the implementation as you requested:

  • Replaced adapter classes with standalone forward functions
  • Created _MODEL_CONFIG mapping for forward functions and coefficients
  • Removed cache_fn/compute_fn closures - now using direct if/else logic in each forward
  • Extracted utility functions: _should_compute(), _update_state(), _apply_cached_residual()
  • Removed enable_teacache() - now only enable_cache(TeaCacheConfig(...))
  • Inlined modulation extractors into forward functions

This does introduce some code duplication - each forward function now has the same if/else pattern:

  if _should_compute(state, modulated_inp, hook.coefficients, hook.config.rel_l1_thresh):
      # compute full transformer
      _update_state(state, output, original, modulated_inp)
  else:
      output = _apply_cached_residual(state, input, modulated_inp)

But the control flow is now much clearer - you can read each forward function top-to-bottom without jumping between closures and hook methods.

Let me know if you'd like any further changes!

@sayakpaul sayakpaul requested a review from DN6 January 8, 2026 06:33
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
… isolation

Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
…elpers

Signed-off-by: Prajwal A <prajwalanagani@gmail.com>
@LawJarp-A
Copy link
Author

LawJarp-A commented Jan 12, 2026

@DN6 @sayakpaul I spent the weekend going over the code again to understand and simplify

  • I have updated the cache context to be set in the denosing loop itself
  • removed redundant code
  • tested it with all models on a h100 and updated it in the PR description

I have kept it with per model forward function like you requested instead of the common adapter pattern I was using before.
Please review it now, I think it addresses all the recent feedback I have recieved

Btw, below are the images generated w and w/o cache

Mochi
image

lumina2
image

flux
image

cogxvideo
image

LawJarp-A and others added 2 commits January 12, 2026 16:31
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left some comments. LMK if they make sense.

Comment on lines +45 to +47
# Fallback to default context for backward compatibility with
# pipelines that don't call cache_context()
context = "_default"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this branch not error out like previous?

Comment on lines +109 to +110
if prev_mean.item() > 1e-9:
return ((current - previous).abs().mean() / prev_mean).item()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to make it data-dependent (item() call)? Raising it because it makes torch.compile cry.

Comment on lines 547 to 549
attention_kwargs, lora_scale = _extract_lora_scale(attention_kwargs)
if USE_PEFT_BACKEND:
scale_lora_layers(module, lora_scale)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check if the underlying model class inherits from PeftLoaderMixin and if so, we should do it.

@sayakpaul sayakpaul requested a review from Copilot January 20, 2026 09:09
@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Jan 20, 2026

Style bot fixed some files and pushed the changes.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements TeaCache (Timestep Embedding Aware Cache), a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.

Changes:

  • Adds TeaCache hook system with model-specific forward implementations for FLUX, Mochi, Lumina2, and CogVideoX models
  • Integrates TeaCache with the existing CacheMixin infrastructure for unified cache management
  • Implements StateManager improvements for context-aware state isolation (CFG support)

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
src/diffusers/hooks/teacache.py Core TeaCache implementation with polynomial rescaling, model auto-detection, and specialized forward functions for each supported model
src/diffusers/models/cache_utils.py Integration of TeaCacheConfig into enable_cache/disable_cache methods
src/diffusers/hooks/init.py Export TeaCacheConfig, apply_teacache, and StateManager
src/diffusers/hooks/hooks.py StateManager enhancement with default context fallback for backward compatibility
src/diffusers/models/transformers/transformer_lumina2.py Add CacheMixin to Lumina2Transformer2DModel
tests/hooks/test_teacache.py Comprehensive unit tests for config validation, state management, and model detection

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +231 to +237
f"Please provide a float value between 0.1 and 1.0."
)
if self.rel_l1_thresh <= 0:
raise ValueError(
f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}. "
f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. "
f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup."
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation logic checks for rel_l1_thresh <= 0 but zero values are arguably valid since they would force computation at every step (effectively disabling caching). Consider whether the check should be < 0 instead, or document why zero is explicitly disallowed.

Suggested change
f"Please provide a float value between 0.1 and 1.0."
)
if self.rel_l1_thresh <= 0:
raise ValueError(
f"rel_l1_thresh must be positive, got {self.rel_l1_thresh}. "
f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. "
f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup."
f"Please provide a float value >= 0.0 (values between 0.1 and 1.0 are recommended)."
)
if self.rel_l1_thresh < 0:
raise ValueError(
f"rel_l1_thresh must be non-negative, got {self.rel_l1_thresh}. "
f"Based on the TeaCache paper, values between 0.1 and 0.3 work best. "
f"Try 0.25 for 1.5x speedup or 0.6 for 2x speedup. "
f"Note that rel_l1_thresh=0.0 effectively disables caching by forcing computation at every step."

Copilot uses AI. Check for mistakes.
state.cnt = 0
state.accumulated_rel_l1_distance = 0.0
state.previous_modulated_input = None
state.previous_residual = None
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _maybe_reset_state_for_new_inference method doesn't reset cache_dict and uncond_seq_len which are used by Lumina2. This could cause stale cache data to persist across inference runs when using Lumina2 models. Consider calling state.reset() instead of manually resetting individual fields, or add these Lumina2-specific fields to the reset logic.

Suggested change
state.previous_residual = None
state.previous_residual = None
# Reset Lumina2-specific state to avoid stale cache/data between inference runs
if hasattr(state, "cache_dict") and state.cache_dict is not None:
# Clear in-place to preserve any existing references to the cache dict
state.cache_dict.clear()
if hasattr(state, "uncond_seq_len"):
state.uncond_seq_len = None

Copilot uses AI. Check for mistakes.
Comment on lines 200 to 213
Example:
```python
>>> from diffusers import FluxPipeline
>>> from diffusers.hooks import TeaCacheConfig

>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> config = TeaCacheConfig(rel_l1_thresh=0.2)
>>> pipe.transformer.enable_cache(config)

>>> image = pipe("A cat sitting on a windowsill", num_inference_steps=4).images[0]
>>> pipe.transformer.disable_cache()
```
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The example code in the docstring references torch but doesn't show the import statement. Consider adding import torch to the example for completeness.

Copilot uses AI. Check for mistakes.
@LawJarp-A
Copy link
Author

Thanks for the review. Taking a look

@sayakpaul
Copy link
Member

sayakpaul commented Feb 16, 2026

@LawJarp-A I am guessing the Copilot review comments were resolved? There also seems to be a couple of unresolved comments.

@sayakpaul
Copy link
Member

Cc: @LiewFeng would you like to give this a review as well?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants

Comments