Update app.py
Browse files
app.py
CHANGED
@@ -23,12 +23,15 @@ def infer(prompt, progress=gr.Progress(track_tqdm=True)):
|
|
23 |
|
24 |
return f"{prompt}.wav"
|
25 |
|
26 |
-
def infer_img2img(prompt, audio_path):
|
27 |
|
28 |
pretrained_model_name_or_path = "auffusion/auffusion-full-no-adapter"
|
29 |
dtype = torch.float16
|
30 |
device = "cuda"
|
31 |
|
|
|
|
|
|
|
32 |
vocoder = Generator.from_pretrained(pretrained_model_name_or_path, subfolder="vocoder")
|
33 |
vocoder = vocoder.to(device=device, dtype=dtype)
|
34 |
|
|
|
23 |
|
24 |
return f"{prompt}.wav"
|
25 |
|
26 |
+
def infer_img2img(prompt, audio_path, progress=gr.Progress(track_tqdm=True)):
|
27 |
|
28 |
pretrained_model_name_or_path = "auffusion/auffusion-full-no-adapter"
|
29 |
dtype = torch.float16
|
30 |
device = "cuda"
|
31 |
|
32 |
+
if not os.path.isdir(pretrained_model_name_or_path):
|
33 |
+
pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)
|
34 |
+
|
35 |
vocoder = Generator.from_pretrained(pretrained_model_name_or_path, subfolder="vocoder")
|
36 |
vocoder = vocoder.to(device=device, dtype=dtype)
|
37 |
|