Spaces:
Running
on
Zero
Running
on
Zero
Update auffusion_pipeline.py
Browse files- auffusion_pipeline.py +17 -7
auffusion_pipeline.py
CHANGED
@@ -469,17 +469,18 @@ class AuffusionPipeline(DiffusionPipeline):
|
|
469 |
vocoder: Generator = None,
|
470 |
requires_safety_checker: bool = False,
|
471 |
adapter_list: Optional[List[Callable]] = None,
|
472 |
-
tokenizer_model_max_length: Optional[int] = 77,
|
473 |
):
|
474 |
super().__init__()
|
475 |
-
|
|
|
476 |
self.text_encoder_list = text_encoder_list
|
477 |
self.tokenizer_list = tokenizer_list
|
478 |
-
self.vocoder = vocoder
|
479 |
self.adapter_list = adapter_list
|
|
|
480 |
self.tokenizer_model_max_length = tokenizer_model_max_length
|
481 |
-
|
482 |
-
# Register
|
483 |
self.register_modules(
|
484 |
vae=vae,
|
485 |
unet=unet,
|
@@ -488,10 +489,19 @@ class AuffusionPipeline(DiffusionPipeline):
|
|
488 |
feature_extractor=feature_extractor,
|
489 |
vocoder=vocoder if vocoder is not None else None,
|
490 |
)
|
491 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
492 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
493 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
494 |
-
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
495 |
|
496 |
|
497 |
@classmethod
|
|
|
469 |
vocoder: Generator = None,
|
470 |
requires_safety_checker: bool = False,
|
471 |
adapter_list: Optional[List[Callable]] = None,
|
472 |
+
tokenizer_model_max_length: Optional[int] = 77,
|
473 |
):
|
474 |
super().__init__()
|
475 |
+
|
476 |
+
# Store list-based components and non-module fields as attributes
|
477 |
self.text_encoder_list = text_encoder_list
|
478 |
self.tokenizer_list = tokenizer_list
|
|
|
479 |
self.adapter_list = adapter_list
|
480 |
+
self.vocoder = vocoder # If it's a torch.nn.Module, you can still register it below
|
481 |
self.tokenizer_model_max_length = tokenizer_model_max_length
|
482 |
+
|
483 |
+
# Register torch modules only
|
484 |
self.register_modules(
|
485 |
vae=vae,
|
486 |
unet=unet,
|
|
|
489 |
feature_extractor=feature_extractor,
|
490 |
vocoder=vocoder if vocoder is not None else None,
|
491 |
)
|
492 |
+
|
493 |
+
# Register config-only (non-module) components — avoids ValueError during .to()
|
494 |
+
self.register_to_config(
|
495 |
+
requires_safety_checker=requires_safety_checker,
|
496 |
+
text_encoder_list=text_encoder_list,
|
497 |
+
tokenizer_list=tokenizer_list,
|
498 |
+
adapter_list=adapter_list,
|
499 |
+
tokenizer_model_max_length=tokenizer_model_max_length,
|
500 |
+
)
|
501 |
+
|
502 |
+
# Other logic
|
503 |
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
504 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
|
|
505 |
|
506 |
|
507 |
@classmethod
|