Sayoyo commited on
Commit
3b7cae0
·
1 Parent(s): 97264de

[fix] model load dtype

Browse files
Files changed (1) hide show
  1. pipeline_ace_step.py +2 -2
pipeline_ace_step.py CHANGED
@@ -143,7 +143,7 @@ class ACEStepPipeline:
143
  self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
144
  self.music_dcae.to(device).eval().to(self.dtype)
145
 
146
- self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
147
  self.ace_step_transformer.to(device).eval().to(self.dtype)
148
 
149
  lang_segment = LangSegment()
@@ -158,7 +158,7 @@ class ACEStepPipeline:
158
  ])
159
  self.lang_segment = lang_segment
160
  self.lyric_tokenizer = VoiceBpeTokenizer()
161
- text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path).eval()
162
  text_encoder_model = text_encoder_model.to(device).to(self.dtype)
163
  text_encoder_model.requires_grad_(False)
164
  self.text_encoder_model = text_encoder_model
 
143
  self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
144
  self.music_dcae.to(device).eval().to(self.dtype)
145
 
146
+ self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype)
147
  self.ace_step_transformer.to(device).eval().to(self.dtype)
148
 
149
  lang_segment = LangSegment()
 
158
  ])
159
  self.lang_segment = lang_segment
160
  self.lyric_tokenizer = VoiceBpeTokenizer()
161
+ text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval()
162
  text_encoder_model = text_encoder_model.to(device).to(self.dtype)
163
  text_encoder_model.requires_grad_(False)
164
  self.text_encoder_model = text_encoder_model