fffiloni commited on
Commit
096e320
·
verified ·
1 Parent(s): 58caec3

Update auffusion_pipeline.py

Browse files
Files changed (1) hide show
  1. auffusion_pipeline.py +17 -7
auffusion_pipeline.py CHANGED
@@ -469,17 +469,18 @@ class AuffusionPipeline(DiffusionPipeline):
469
  vocoder: Generator = None,
470
  requires_safety_checker: bool = False,
471
  adapter_list: Optional[List[Callable]] = None,
472
- tokenizer_model_max_length: Optional[int] = 77, # 77 is the default value for the CLIPTokenizer(and set for other models)
473
  ):
474
  super().__init__()
475
-
 
476
  self.text_encoder_list = text_encoder_list
477
  self.tokenizer_list = tokenizer_list
478
- self.vocoder = vocoder
479
  self.adapter_list = adapter_list
 
480
  self.tokenizer_model_max_length = tokenizer_model_max_length
481
-
482
- # Register only actual modules (not lists)
483
  self.register_modules(
484
  vae=vae,
485
  unet=unet,
@@ -488,10 +489,19 @@ class AuffusionPipeline(DiffusionPipeline):
488
  feature_extractor=feature_extractor,
489
  vocoder=vocoder if vocoder is not None else None,
490
  )
491
-
 
 
 
 
 
 
 
 
 
 
492
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
493
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
494
- self.register_to_config(requires_safety_checker=requires_safety_checker)
495
 
496
 
497
  @classmethod
 
469
  vocoder: Generator = None,
470
  requires_safety_checker: bool = False,
471
  adapter_list: Optional[List[Callable]] = None,
472
+ tokenizer_model_max_length: Optional[int] = 77,
473
  ):
474
  super().__init__()
475
+
476
+ # Store list-based components and non-module fields as attributes
477
  self.text_encoder_list = text_encoder_list
478
  self.tokenizer_list = tokenizer_list
 
479
  self.adapter_list = adapter_list
480
+ self.vocoder = vocoder # If it's a torch.nn.Module, you can still register it below
481
  self.tokenizer_model_max_length = tokenizer_model_max_length
482
+
483
+ # Register torch modules only
484
  self.register_modules(
485
  vae=vae,
486
  unet=unet,
 
489
  feature_extractor=feature_extractor,
490
  vocoder=vocoder if vocoder is not None else None,
491
  )
492
+
493
+ # Register config-only (non-module) components — avoids ValueError during .to()
494
+ self.register_to_config(
495
+ requires_safety_checker=requires_safety_checker,
496
+ text_encoder_list=text_encoder_list,
497
+ tokenizer_list=tokenizer_list,
498
+ adapter_list=adapter_list,
499
+ tokenizer_model_max_length=tokenizer_model_max_length,
500
+ )
501
+
502
+ # Other logic
503
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
504
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
505
 
506
 
507
  @classmethod