jbilcke-hf HF Staff committed on
Commit
624c83b
·
verified ·
1 Parent(s): 65e6ade

Update teacache.py

Browse files
Files changed (1) hide show
  1. teacache.py +46 -32
teacache.py CHANGED
@@ -14,7 +14,7 @@
14
 
15
  import re
16
  from dataclasses import dataclass, field
17
- from typing import List, Optional
18
 
19
  import numpy as np
20
  import torch
@@ -72,12 +72,10 @@ _DEFAULT_SKIP_LAYER_IDENTIFIERS = [
72
  @dataclass
73
  class TeaCacheConfig:
74
  l1_threshold: Optional[float] = None
75
-
76
- # Julian: this is a fix for @dataclass
77
  skip_layer_identifiers: List[str] = field(default_factory=lambda: _DEFAULT_SKIP_LAYER_IDENTIFIERS.copy())
78
-
79
  _polynomial_coefficients: Optional[List[float]] = None
80
-
81
  timestep_modulated_layer_identifier: Optional[str] = None
82
 
83
 
@@ -112,13 +110,15 @@ def apply_teacache(
112
  if denoiser is None:
113
  denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
114
 
115
- # Get model type name for lookup
116
- model_type_name = denoiser.__class__.__name__
 
117
  logger.info(f"Detected model type: {model_type_name}")
118
 
 
119
  supported_model_types = tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())
120
  is_supported_model = False
121
-
122
  # Check if model_type is one of the supported types
123
  for supported_type in supported_model_types:
124
  if model_type == supported_type:
@@ -128,43 +128,57 @@ def apply_teacache(
128
  if is_supported_model:
129
  if config.l1_threshold is None:
130
  logger.info(
131
- f"No L1 threshold was provided for {type(denoiser)}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
132
  f"For higher speedup, increase the threshold."
133
  )
134
- config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[type(denoiser)]
135
  if config.timestep_modulated_layer_identifier is None:
136
  logger.info(
137
- f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using default identifier as provided in the TeaCache paper."
138
  )
139
- config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[type(denoiser)]
 
 
 
 
140
  if config._polynomial_coefficients is None:
141
  logger.info(
142
- f"No polynomial coefficients were provided for {type(denoiser)}. Using default coefficients as provided in the TeaCache paper."
143
  )
144
- config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[type(denoiser)]
145
  else:
146
- if config.l1_threshold is None:
147
- raise ValueError(
148
- f"No L1 threshold was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
149
- f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
150
- )
151
- if config.timestep_modulated_layer_identifier is None:
152
- raise ValueError(
153
- f"No timestep modulated layer identifier was provided for {type(denoiser)}. Using TeaCache with this model is not supported "
154
- f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
155
- )
156
- if config._polynomial_coefficients is None:
157
- raise ValueError(
158
- f"No polynomial coefficients were provided for {type(denoiser)}. Using TeaCache with this model is not "
159
- f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
160
- f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
161
- )
 
 
 
 
 
 
 
 
 
 
162
 
163
  timestep_modulated_layer_matches = list(
164
  {
165
  module
166
  for name, module in denoiser.named_modules()
167
- if re.match(config.timestep_modulated_layer_identifier, name)
168
  }
169
  )
170
 
@@ -190,7 +204,7 @@ def apply_teacache(
190
  {
191
  module
192
  for name, module in denoiser.named_modules()
193
- if any(re.match(identifier, name) for identifier in skip_layer_identifiers)
194
  }
195
  )
196
 
 
14
 
15
  import re
16
  from dataclasses import dataclass, field
17
+ from typing import List, Optional, Tuple, Dict, Type, Any, Union
18
 
19
  import numpy as np
20
  import torch
 
72
  @dataclass
73
  class TeaCacheConfig:
74
  l1_threshold: Optional[float] = None
75
+ # Fixed with default_factory
 
76
  skip_layer_identifiers: List[str] = field(default_factory=lambda: _DEFAULT_SKIP_LAYER_IDENTIFIERS.copy())
 
77
  _polynomial_coefficients: Optional[List[float]] = None
78
+ # Added missing field
79
  timestep_modulated_layer_identifier: Optional[str] = None
80
 
81
 
 
110
  if denoiser is None:
111
  denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
112
 
113
+ # Fix for isinstance() issue - convert to explicit type checking
114
+ model_type = denoiser.__class__
115
+ model_type_name = model_type.__name__
116
  logger.info(f"Detected model type: {model_type_name}")
117
 
118
+ # Modified to avoid using isinstance with dict_keys
119
  supported_model_types = tuple(_MODEL_TO_POLY_COEFFICIENTS.keys())
120
  is_supported_model = False
121
+
122
  # Check if model_type is one of the supported types
123
  for supported_type in supported_model_types:
124
  if model_type == supported_type:
 
128
  if is_supported_model:
129
  if config.l1_threshold is None:
130
  logger.info(
131
+ f"No L1 threshold was provided for {model_type}. Using default threshold as provided in the TeaCache paper for 1.5x speedup. "
132
  f"For higher speedup, increase the threshold."
133
  )
134
+ config.l1_threshold = _MODEL_TO_1_POINT_5X_SPEEDUP_THRESHOLD[model_type]
135
  if config.timestep_modulated_layer_identifier is None:
136
  logger.info(
137
+ f"No timestep modulated layer identifier was provided for {model_type}. Using default identifier as provided in the TeaCache paper."
138
  )
139
+ if model_type in _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER:
140
+ config.timestep_modulated_layer_identifier = _MODEL_TO_TIMESTEP_MODULATED_LAYER_IDENTIFIER[model_type]
141
+ else:
142
+ # Use a default for model types without a specific identifier
143
+ config.timestep_modulated_layer_identifier = "transformer_blocks.0.norm1"
144
  if config._polynomial_coefficients is None:
145
  logger.info(
146
+ f"No polynomial coefficients were provided for {model_type}. Using default coefficients as provided in the TeaCache paper."
147
  )
148
+ config._polynomial_coefficients = _MODEL_TO_POLY_COEFFICIENTS[model_type]
149
  else:
150
+ # Handle unsupported model types with more helpful defaults for LTX models
151
+ if model_type_name == "LTXVideoTransformer3DModel" or "LTX" in model_type_name:
152
+ logger.info(f"Model {model_type_name} appears to be an LTX model variant. Using LTX defaults.")
153
+ if config.l1_threshold is None:
154
+ config.l1_threshold = 0.05
155
+ if config.timestep_modulated_layer_identifier is None:
156
+ config.timestep_modulated_layer_identifier = "transformer_blocks.0.norm1"
157
+ if config._polynomial_coefficients is None:
158
+ config._polynomial_coefficients = [2.14700694e01, -1.28016453e01, 2.31279151e00, 7.92487521e-01, 9.69274326e-03]
159
+ else:
160
+ if config.l1_threshold is None:
161
+ raise ValueError(
162
+ f"No L1 threshold was provided for {model_type}. Using TeaCache with this model is not supported "
163
+ f"in Diffusers. Please provide the L1 threshold in the config by setting the `l1_threshold` attribute."
164
+ )
165
+ if config.timestep_modulated_layer_identifier is None:
166
+ raise ValueError(
167
+ f"No timestep modulated layer identifier was provided for {model_type}. Using TeaCache with this model is not supported "
168
+ f"in Diffusers. Please provide the layer identifier in the config by setting the `timestep_modulated_layer_identifier` attribute."
169
+ )
170
+ if config._polynomial_coefficients is None:
171
+ raise ValueError(
172
+ f"No polynomial coefficients were provided for {model_type}. Using TeaCache with this model is not "
173
+ f"supported in Diffusers. Please provide the polynomial coefficients in the config by setting the "
174
+ f"`_polynomial_coefficients` attribute. Automatic calibration will be implemented in the future."
175
+ )
176
 
177
  timestep_modulated_layer_matches = list(
178
  {
179
  module
180
  for name, module in denoiser.named_modules()
181
+ if re.search(config.timestep_modulated_layer_identifier, name)
182
  }
183
  )
184
 
 
204
  {
205
  module
206
  for name, module in denoiser.named_modules()
207
+ if any(re.search(identifier, name) for identifier in skip_layer_identifiers)
208
  }
209
  )
210