Spaces:
Running
Running
[fix] model load dtype
Browse files- pipeline_ace_step.py +2 -2
pipeline_ace_step.py
CHANGED
@@ -143,7 +143,7 @@ class ACEStepPipeline:
|
|
143 |
self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
|
144 |
self.music_dcae.to(device).eval().to(self.dtype)
|
145 |
|
146 |
-
self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
|
147 |
self.ace_step_transformer.to(device).eval().to(self.dtype)
|
148 |
|
149 |
lang_segment = LangSegment()
|
@@ -158,7 +158,7 @@ class ACEStepPipeline:
|
|
158 |
])
|
159 |
self.lang_segment = lang_segment
|
160 |
self.lyric_tokenizer = VoiceBpeTokenizer()
|
161 |
-
text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path).eval()
|
162 |
text_encoder_model = text_encoder_model.to(device).to(self.dtype)
|
163 |
text_encoder_model.requires_grad_(False)
|
164 |
self.text_encoder_model = text_encoder_model
|
|
|
143 |
self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
|
144 |
self.music_dcae.to(device).eval().to(self.dtype)
|
145 |
|
146 |
+
self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype)
|
147 |
self.ace_step_transformer.to(device).eval().to(self.dtype)
|
148 |
|
149 |
lang_segment = LangSegment()
|
|
|
158 |
])
|
159 |
self.lang_segment = lang_segment
|
160 |
self.lyric_tokenizer = VoiceBpeTokenizer()
|
161 |
+
text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval()
|
162 |
text_encoder_model = text_encoder_model.to(device).to(self.dtype)
|
163 |
text_encoder_model.requires_grad_(False)
|
164 |
self.text_encoder_model = text_encoder_model
|