aryo100 committed on
Commit
9d89c87
·
1 Parent(s): 8b6e5c4

update models.py & add model

Browse files
modules/models.py CHANGED
@@ -184,7 +184,8 @@ class VoiceConvertModel:
184
  )
185
  os.makedirs(output_dir, exist_ok=True)
186
  input_audio_splitext = os.path.splitext(os.path.basename(input_audio))[0]
187
- model_splitext = os.path.splitext(self.model_name)[0]
 
188
  index = 0
189
  existing_files = os.listdir(output_dir)
190
  for existing_file in existing_files:
@@ -202,7 +203,33 @@ class VoiceConvertModel:
202
  return audio_opt
203
 
204
  def get_index_path(self, speaker_id: int):
205
- basename = os.path.splitext(self.model_name)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  speaker_index_path = os.path.join(
207
  MODELS_DIR,
208
  "checkpoints",
@@ -239,10 +266,27 @@ def get_embedder(embedder_name):
239
  def load_embedder(emb_file: str, emb_name: str):
240
  global embedder_model, loaded_embedder_model
241
  emb_file = os.path.join(MODELS_DIR, "embeddings", emb_file)
242
- models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
243
- [emb_file],
244
- suffix="",
245
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  embedder_model = models[0]
247
  embedder_model = embedder_model.to(device)
248
 
@@ -256,8 +300,23 @@ def load_embedder(emb_file: str, emb_name: str):
256
 
257
 
258
  def get_vc_model(model_name: str):
259
- model_path = os.path.join(MODELS_DIR, "checkpoints", model_name)
260
- weight = torch.load(model_path, map_location="cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  return VoiceConvertModel(model_name, weight)
262
 
263
 
 
184
  )
185
  os.makedirs(output_dir, exist_ok=True)
186
  input_audio_splitext = os.path.splitext(os.path.basename(input_audio))[0]
187
+ # model_splitext = os.path.splitext(self.model_name)[0]
188
+ model_splitext = os.path.splitext(os.path.basename(self.model_name))[0]
189
  index = 0
190
  existing_files = os.listdir(output_dir)
191
  for existing_file in existing_files:
 
203
  return audio_opt
204
 
205
  def get_index_path(self, speaker_id: int):
206
+ # basename = os.path.splitext(self.model_name)[0]
207
+ basename = os.path.splitext(os.path.basename(self.model_name))[0]
208
+
209
+ # Handle relative paths (e.g., "weights/zet_test1.pth")
210
+ if "/" in self.model_name:
211
+ # It's a relative path, look for index in the same directory
212
+ model_dir = os.path.dirname(os.path.join(ROOT_DIR, self.model_name))
213
+ relative_index_path = os.path.join(model_dir, f"{basename}.{speaker_id}.index")
214
+ if os.path.exists(relative_index_path):
215
+ return relative_index_path
216
+
217
+ # Check generic index in same directory
218
+ relative_generic_index = os.path.join(model_dir, f"{basename}.index")
219
+ if os.path.exists(relative_generic_index):
220
+ return relative_generic_index
221
+
222
+ # Check in weights folder first (for custom models)
223
+ weights_dir = os.path.join(ROOT_DIR, "weights")
224
+ weights_index_path = os.path.join(weights_dir, f"{basename}.{speaker_id}.index")
225
+ if os.path.exists(weights_index_path):
226
+ return weights_index_path
227
+
228
+ # Check generic index in weights folder
229
+ weights_generic_index = os.path.join(weights_dir, f"{basename}.index")
230
+ if os.path.exists(weights_generic_index):
231
+ return weights_generic_index
232
+
233
  speaker_index_path = os.path.join(
234
  MODELS_DIR,
235
  "checkpoints",
 
266
  def load_embedder(emb_file: str, emb_name: str):
267
  global embedder_model, loaded_embedder_model
268
  emb_file = os.path.join(MODELS_DIR, "embeddings", emb_file)
269
+
270
+ # Monkey patch torch.load to use weights_only=False for fairseq compatibility
271
+ original_torch_load = torch.load
272
+ def patched_torch_load(*args, **kwargs):
273
+ kwargs['weights_only'] = False
274
+ return original_torch_load(*args, **kwargs)
275
+
276
+ try:
277
+ torch.load = patched_torch_load
278
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
279
+ [emb_file],
280
+ suffix="",
281
+ )
282
+ finally:
283
+ # Restore original torch.load
284
+ torch.load = original_torch_load
285
+
286
+ # models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
287
+ # [emb_file],
288
+ # suffix="",
289
+ # )
290
  embedder_model = models[0]
291
  embedder_model = embedder_model.to(device)
292
 
 
300
 
301
 
302
def get_vc_model(model_name: str):
    """Resolve a voice-conversion checkpoint on disk and wrap it in a VoiceConvertModel.

    Resolution order:
      1. If ``model_name`` contains ``/`` it is treated as a path relative to
         ``ROOT_DIR`` (e.g. ``"weights/zet_test1.pth"``) and must exist.
      2. Otherwise the ``ROOT_DIR/weights`` folder is checked first (custom
         user models).
      3. Finally it falls back to ``MODELS_DIR/checkpoints``.

    Args:
        model_name: Bare checkpoint filename or a ``/``-separated path
            relative to ``ROOT_DIR``.

    Returns:
        A ``VoiceConvertModel`` constructed from the loaded state dict.

    Raises:
        FileNotFoundError: if ``model_name`` is a relative path and the file
            does not exist.
    """
    # NOTE(review): only "/" is tested, so a Windows-style "weights\\x.pth"
    # would be treated as a bare filename — confirm whether that platform
    # needs to be supported here.
    if "/" in model_name:
        # Relative path: resolve against the project root and fail loudly
        # if the file is missing.
        model_path = os.path.join(ROOT_DIR, model_name)
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
    else:
        # Bare filename: prefer the user "weights" folder over the bundled
        # checkpoints directory.
        weights_path = os.path.join(ROOT_DIR, "weights", model_name)
        if os.path.exists(weights_path):
            model_path = weights_path
        else:
            model_path = os.path.join(MODELS_DIR, "checkpoints", model_name)

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects. Only load checkpoints from trusted sources.
    weight = torch.load(model_path, map_location="cpu", weights_only=False)
    return VoiceConvertModel(model_name, weight)
321
 
322
 
weights/zet_test1.0.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a194ba96c803cb1530676d495d5bcab6224a460aa535fc1b8bbe4806d5d9c186
3
+ size 3234139
weights/zet_test1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:417723f2279e91ffd8fa85a555115cb36b796c898d6a25c25394596efc303382
3
+ size 55192835