Spaces:
Running
on
Zero
Running
on
Zero
Sync from GitHub repo
Browse files. This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
src/f5_tts/configs/E2TTS_Base_train.yaml
CHANGED
|
@@ -41,4 +41,4 @@ ckpts:
|
|
| 41 |
logger: wandb # wandb | tensorboard | None
|
| 42 |
save_per_updates: 50000 # save checkpoint per steps
|
| 43 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 44 |
-
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
|
|
|
| 41 |
logger: wandb # wandb | tensorboard | None
|
| 42 |
save_per_updates: 50000 # save checkpoint per steps
|
| 43 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 44 |
+
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
src/f5_tts/configs/E2TTS_Small_train.yaml
CHANGED
|
@@ -41,4 +41,4 @@ ckpts:
|
|
| 41 |
logger: wandb # wandb | tensorboard | None
|
| 42 |
save_per_updates: 50000 # save checkpoint per steps
|
| 43 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 44 |
-
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
|
|
|
| 41 |
logger: wandb # wandb | tensorboard | None
|
| 42 |
save_per_updates: 50000 # save checkpoint per steps
|
| 43 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 44 |
+
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
src/f5_tts/configs/F5TTS_Base_train.yaml
CHANGED
|
@@ -43,4 +43,4 @@ ckpts:
|
|
| 43 |
logger: wandb # wandb | tensorboard | None
|
| 44 |
save_per_updates: 50000 # save checkpoint per steps
|
| 45 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 46 |
-
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
|
|
|
| 43 |
logger: wandb # wandb | tensorboard | None
|
| 44 |
save_per_updates: 50000 # save checkpoint per steps
|
| 45 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 46 |
+
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
src/f5_tts/configs/F5TTS_Small_train.yaml
CHANGED
|
@@ -43,4 +43,4 @@ ckpts:
|
|
| 43 |
logger: wandb # wandb | tensorboard | None
|
| 44 |
save_per_updates: 50000 # save checkpoint per steps
|
| 45 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 46 |
-
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
|
|
|
| 43 |
logger: wandb # wandb | tensorboard | None
|
| 44 |
save_per_updates: 50000 # save checkpoint per steps
|
| 45 |
last_per_steps: 5000 # save last checkpoint per steps
|
| 46 |
+
save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
|
src/f5_tts/infer/utils_infer.py
CHANGED
|
@@ -138,7 +138,11 @@ asr_pipe = None
|
|
| 138 |
def initialize_asr_pipeline(device: str = device, dtype=None):
|
| 139 |
if dtype is None:
|
| 140 |
dtype = (
|
| 141 |
-
torch.float16
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
)
|
| 143 |
global asr_pipe
|
| 144 |
asr_pipe = pipeline(
|
|
@@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
|
|
| 171 |
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
| 172 |
if dtype is None:
|
| 173 |
dtype = (
|
| 174 |
-
torch.float16
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
model = model.to(dtype)
|
| 177 |
|
|
|
|
| 138 |
def initialize_asr_pipeline(device: str = device, dtype=None):
|
| 139 |
if dtype is None:
|
| 140 |
dtype = (
|
| 141 |
+
torch.float16
|
| 142 |
+
if "cuda" in device
|
| 143 |
+
and torch.cuda.get_device_properties(device).major >= 6
|
| 144 |
+
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
| 145 |
+
else torch.float32
|
| 146 |
)
|
| 147 |
global asr_pipe
|
| 148 |
asr_pipe = pipeline(
|
|
|
|
| 175 |
def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
|
| 176 |
if dtype is None:
|
| 177 |
dtype = (
|
| 178 |
+
torch.float16
|
| 179 |
+
if "cuda" in device
|
| 180 |
+
and torch.cuda.get_device_properties(device).major >= 6
|
| 181 |
+
and not torch.cuda.get_device_name().endswith("[ZLUDA]")
|
| 182 |
+
else torch.float32
|
| 183 |
)
|
| 184 |
model = model.to(dtype)
|
| 185 |
|