Update teacache.py
Browse files- teacache.py +46 -32
teacache.py
CHANGED
@@ -14,7 +14,7 @@
|
|
14 |
|
15 |
import re
|
16 |
from dataclasses import dataclass, field
|
17 |
-
from typing import List, Optional
|
18 |
|
19 |
import numpy as np
|
20 |
import torch
|
@@ -72,12 +72,10 @@ _DEFAULT_SKIP_LAYER_IDENTIFIERS = [
|
|
72 |
@dataclass
|
73 |
class TeaCacheConfig:
|
74 |
l1_threshold: Optional[float] = None
|
75 |
-
|
76 |
-
# Julian: this is a fix for @dataclass
|
77 |
skip_layer_identifiers: List[str] = field(default_factory=lambda: _DEFAULT_SKIP_LAYER_IDENTIFIERS.copy())
|
78 |
-
|
79 |
_polynomial_coefficients: Optional[List[float]] = None
|
80 |
-
|
81 |
timestep_modulated_layer_identifier: Optional[str] = None
|
82 |
|
83 |
|
@@ -112,13 +110,15 @@ def apply_teacache(
|
|
112 |
if denoiser is None:
|
113 |
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
|
114 |
|
115 |
-
#
|
116 |
-
|
|
|
117 |
logger.info(f"Detected model type: {model_type_name}")
|
118 |
|
|
|
119 |
supported_model_types = tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())
|
120 |
is_supported_model = False
|
121 |
-
|
122 |
# Check if model_type is one of the supported types
|
123 |
for supported_type in supported_model_types:
|
124 |
if model_type == supported_type:
|
@@ -128,43 +128,57 @@ def apply_teacache(
|
|
128 |
if is_supported_model:
|
129 |
if config.l1_threshold is None:
|
130 |
logger.info(
|
131 |
-
f"No L1 threshold was provided for {
|
132 |
f"For higher speedup, increase the threshold."
|
133 |
)
|
134 |
-
config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[
|
135 |
if config.timestep_modulated_layer_identifier is None:
|
136 |
logger.info(
|
137 |
-
f"No timestep modulated layer identifier was provided for {
|
138 |
)
|
139 |
-
|
|
|
|
|
|
|
|
|
140 |
if config._polynomial_coefficients is None:
|
141 |
logger.info(
|
142 |
-
f"No polynomial coefficients were provided for {
|
143 |
)
|
144 |
-
config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[
|
145 |
else:
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
|
163 |
timestep_modulated_layer_matches = list(
|
164 |
{
|
165 |
module
|
166 |
for name, module in denoiser.named_modules()
|
167 |
-
if re.
|
168 |
}
|
169 |
)
|
170 |
|
@@ -190,7 +204,7 @@ def apply_teacache(
|
|
190 |
{
|
191 |
module
|
192 |
for name, module in denoiser.named_modules()
|
193 |
-
if any(re.
|
194 |
}
|
195 |
)
|
196 |
|
|
|
14 |
|
15 |
import re
|
16 |
from dataclasses import dataclass, field
|
17 |
+
from typing import List, Optional, Tuple, Dict, Type, Any, Union
|
18 |
|
19 |
import numpy as np
|
20 |
import torch
|
|
|
72 |
@dataclass
|
73 |
class TeaCacheConfig:
|
74 |
l1_threshold: Optional[float] = None
|
75 |
+
# Fixed with default_factory
|
|
|
76 |
skip_layer_identifiers: List[str] = field(default_factory=lambda: _DEFAULT_SKIP_LAYER_IDENTIFIERS.copy())
|
|
|
77 |
_polynomial_coefficients: Optional[List[float]] = None
|
78 |
+
# Added missing field
|
79 |
timestep_modulated_layer_identifier: Optional[str] = None
|
80 |
|
81 |
|
|
|
110 |
if denoiser is None:
|
111 |
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
|
112 |
|
113 |
+
# Fix for isinstance() issue - convert to explicit type checking
|
114 |
+
model_type = denoiser.__class__
|
115 |
+
model_type_name = model_type.__name__
|
116 |
logger.info(f"Detected model type: {model_type_name}")
|
117 |
|
118 |
+
# Modified to avoid using isinstance with dict_keys
|
119 |
supported_model_types = tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())
|
120 |
is_supported_model = False
|
121 |
+
|
122 |
# Check if model_type is one of the supported types
|
123 |
for supported_type in supported_model_types:
|
124 |
if model_type == supported_type:
|
|
|
128 |
if is_supported_model:
|
129 |
if config.l1_threshold is None:
|
130 |
logger.info(
|
131 |
+
f"No L1 threshold was provided for {model_type}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
|
132 |
f"For higher speedup, increase the threshold."
|
133 |
)
|
134 |
+
config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[model_type]
|
135 |
if config.timestep_modulated_layer_identifier is None:
|
136 |
logger.info(
|
137 |
+
f"No timestep modulated layer identifier was provided for {model_type}. Using default identifier as provided in the TeaCache paper."
|
138 |
)
|
139 |
+
if model_type in _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER:
|
140 |
+
config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[model_type]
|
141 |
+
else:
|
142 |
+
# Use a default for model types without a specific identifier
|
143 |
+
config.timestep_modulated_layer_identifier = "transformer_blocks.0.norm1"
|
144 |
if config._polynomial_coefficients is None:
|
145 |
logger.info(
|
146 |
+
f"No polynomial coefficients were provided for {model_type}. Using default coefficients as provided in the TeaCache paper."
|
147 |
)
|
148 |
+
config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[model_type]
|
149 |
else:
|
150 |
+
# Handle unsupported model types with more helpful defaults for LTX models
|
151 |
+
if model_type_name == "LTXVideoTransformer3DModel" or "LTX" in model_type_name:
|
152 |
+
logger.info(f"Model {model_type_name} appears to be an LTX model variant. Using LTX defaults.")
|
153 |
+
if config.l1_threshold is None:
|
154 |
+
config.l1_threshold = 0.05
|
155 |
+
if config.timestep_modulated_layer_identifier is None:
|
156 |
+
config.timestep_modulated_layer_identifier = "transformer_blocks.0.norm1"
|
157 |
+
if config._polynomial_coefficients is None:
|
158 |
+
config._polynomial_coefficients = [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03]
|
159 |
+
else:
|
160 |
+
if config.l1_threshold is None:
|
161 |
+
raise ValueError(
|
162 |
+
f"No L1 threshold was provided for {model_type}. Using TeaCache with this model is not supported "
|
163 |
+
f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
|
164 |
+
)
|
165 |
+
if config.timestep_modulated_layer_identifier is None:
|
166 |
+
raise ValueError(
|
167 |
+
f"No timestep modulated layer identifier was provided for {model_type}. Using TeaCache with this model is not supported "
|
168 |
+
f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
|
169 |
+
)
|
170 |
+
if config._polynomial_coefficients is None:
|
171 |
+
raise ValueError(
|
172 |
+
f"No polynomial coefficients were provided for {model_type}. Using TeaCache with this model is not "
|
173 |
+
f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
|
174 |
+
f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
|
175 |
+
)
|
176 |
|
177 |
timestep_modulated_layer_matches = list(
|
178 |
{
|
179 |
module
|
180 |
for name, module in denoiser.named_modules()
|
181 |
+
if re.search(config.timestep_modulated_layer_identifier, name)
|
182 |
}
|
183 |
)
|
184 |
|
|
|
204 |
{
|
205 |
module
|
206 |
for name, module in denoiser.named_modules()
|
207 |
+
if any(re.search(identifier, name) for identifier in skip_layer_identifiers)
|
208 |
}
|
209 |
)
|
210 |
|