kevinwang676 commited on
Commit
946de3e
·
verified ·
1 Parent(s): 10d4c29

Update api_v2.py

Browse files
Files changed (1) hide show
  1. api_v2.py +102 -6
api_v2.py CHANGED
@@ -192,6 +192,56 @@ def process_audio_path(audio_path) -> Tuple[str, bool]:
192
  # If not a URL or download failed, return the original path
193
  return audio_path, False # Not a temporary file
194
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  parser = argparse.ArgumentParser(description="GPT-SoVITS api")
196
  parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
197
  parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
@@ -543,25 +593,71 @@ async def set_refer_aduio(refer_audio_path: str = None):
543
 
544
  @APP.get("/set_gpt_weights")
545
  async def set_gpt_weights(weights_path: str = None):
 
546
  try:
547
  if weights_path in ["", None]:
548
  return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
549
- tts_pipeline.init_t2s_weights(weights_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  except Exception as e:
 
 
 
 
 
 
 
 
551
  return JSONResponse(status_code=400, content={"message": f"change gpt weight failed", "Exception": str(e)})
552
 
553
- return JSONResponse(status_code=200, content={"message": "success"})
554
-
555
-
556
  @APP.get("/set_sovits_weights")
557
  async def set_sovits_weights(weights_path: str = None):
 
558
  try:
559
  if weights_path in ["", None]:
560
  return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
561
- tts_pipeline.init_vits_weights(weights_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
562
  except Exception as e:
 
 
 
 
 
 
 
 
563
  return JSONResponse(status_code=400, content={"message": f"change sovits weight failed", "Exception": str(e)})
564
- return JSONResponse(status_code=200, content={"message": "success"})
565
 
566
 
567
 
 
192
  # If not a URL or download failed, return the original path
193
  return audio_path, False # Not a temporary file
194
 
195
+ # Function to process weight files (similar to process_audio_path)
196
+ def process_weights_path(weights_path) -> Tuple[str, bool]:
197
+ """
198
+ Process a weights path, downloading it if it's a URL.
199
+
200
+ Args:
201
+ weights_path (str): Path or URL to weights file
202
+
203
+ Returns:
204
+ Tuple[str, bool]: (local_path, is_temporary)
205
+ """
206
+ if weights_path and (weights_path.startswith('http://') or weights_path.startswith('https://') or
207
+ weights_path.startswith('s3://')):
208
+ try:
209
+ # Create temp directory if it doesn't exist
210
+ temp_dir = os.path.join(now_dir, "temp_weights")
211
+ os.makedirs(temp_dir, exist_ok=True)
212
+
213
+ # Generate a filename from the URL
214
+ parsed_url = urllib.parse.urlparse(weights_path)
215
+ filename = os.path.basename(parsed_url.path)
216
+ if not filename:
217
+ filename = f"temp_weights_{hash(weights_path)}.pth"
218
+
219
+ # Full path for downloaded file
220
+ local_path = os.path.join(temp_dir, filename)
221
+
222
+ # Download file
223
+ if weights_path.startswith('s3://'):
224
+ # S3 implementation placeholder
225
+ print(f"Downloading from S3: {weights_path}")
226
+ raise NotImplementedError("S3 download not implemented. Add boto3 library and implementation.")
227
+ else:
228
+ # HTTP/HTTPS download
229
+ print(f"Downloading weights from URL: {weights_path}")
230
+ response = requests.get(weights_path, stream=True)
231
+ response.raise_for_status()
232
+ with open(local_path, 'wb') as f:
233
+ for chunk in response.iter_content(chunk_size=8192):
234
+ f.write(chunk)
235
+
236
+ print(f"Downloaded weights to: {local_path}")
237
+ return local_path, True # Return path and flag indicating it's temporary
238
+ except Exception as e:
239
+ print(f"Error downloading weights file: {e}")
240
+ raise Exception(f"Failed to download weights from URL: {e}")
241
+
242
+ # If not a URL or download failed, return the original path
243
+ return weights_path, False # Not a temporary file
244
+
245
  parser = argparse.ArgumentParser(description="GPT-SoVITS api")
246
  parser.add_argument("-c", "--tts_config", type=str, default="GPT_SoVITS/configs/tts_infer.yaml", help="tts_infer路径")
247
  parser.add_argument("-a", "--bind_addr", type=str, default="127.0.0.1", help="default: 127.0.0.1")
 
593
 
594
  @APP.get("/set_gpt_weights")
595
  async def set_gpt_weights(weights_path: str = None):
596
+ temp_file = None
597
  try:
598
  if weights_path in ["", None]:
599
  return JSONResponse(status_code=400, content={"message": "gpt weight path is required"})
600
+
601
+ # Process the path (download if it's a URL)
602
+ local_path, is_temp = process_weights_path(weights_path)
603
+ if is_temp:
604
+ temp_file = local_path
605
+
606
+ # Load the weights
607
+ tts_pipeline.init_t2s_weights(local_path)
608
+
609
+ # Clean up if it was a temporary file
610
+ # Note: Depending on how init_t2s_weights works, you might need to keep the file
611
+ # If the function loads the file into memory, you can delete it right away
612
+ if temp_file and os.path.exists(temp_file):
613
+ os.remove(temp_file)
614
+ print(f"Removed temporary weights file: {temp_file}")
615
+
616
+ return JSONResponse(status_code=200, content={"message": "success"})
617
  except Exception as e:
618
+ # Clean up temp file in case of error
619
+ if temp_file and os.path.exists(temp_file):
620
+ try:
621
+ os.remove(temp_file)
622
+ print(f"Removed temporary weights file: {temp_file}")
623
+ except Exception as cleanup_error:
624
+ print(f"Error removing temporary file {temp_file}: {cleanup_error}")
625
+
626
  return JSONResponse(status_code=400, content={"message": f"change gpt weight failed", "Exception": str(e)})
627
 
 
 
 
628
  @APP.get("/set_sovits_weights")
629
  async def set_sovits_weights(weights_path: str = None):
630
+ temp_file = None
631
  try:
632
  if weights_path in ["", None]:
633
  return JSONResponse(status_code=400, content={"message": "sovits weight path is required"})
634
+
635
+ # Process the path (download if it's a URL)
636
+ local_path, is_temp = process_weights_path(weights_path)
637
+ if is_temp:
638
+ temp_file = local_path
639
+
640
+ # Load the weights
641
+ tts_pipeline.init_vits_weights(local_path)
642
+
643
+ # Clean up if it was a temporary file
644
+ # Note: Depending on how init_vits_weights works, you might need to keep the file
645
+ # If the function loads the file into memory, you can delete it right away
646
+ if temp_file and os.path.exists(temp_file):
647
+ os.remove(temp_file)
648
+ print(f"Removed temporary weights file: {temp_file}")
649
+
650
+ return JSONResponse(status_code=200, content={"message": "success"})
651
  except Exception as e:
652
+ # Clean up temp file in case of error
653
+ if temp_file and os.path.exists(temp_file):
654
+ try:
655
+ os.remove(temp_file)
656
+ print(f"Removed temporary weights file: {temp_file}")
657
+ except Exception as cleanup_error:
658
+ print(f"Error removing temporary file {temp_file}: {cleanup_error}")
659
+
660
  return JSONResponse(status_code=400, content={"message": f"change sovits weight failed", "Exception": str(e)})
 
661
 
662
 
663