#!/usr/bin/env python3
# Copyright 2025 Xiaomi Corp. (authors: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" | |
This script generates speech with our pre-trained ZipVoice or | |
ZipVoice-Distill models. If no local model is specified, | |
Required files will be automatically downloaded from HuggingFace. | |
Usage: | |
Note: If you having trouble connecting to HuggingFace, | |
try switching endpoint to mirror site: | |
export HF_ENDPOINT=https://hf-mirror.com | |
(1) Inference of a single sentence: | |
python3 -m zipvoice.bin.infer_zipvoice \ | |
--model-name "zipvoice" \ | |
--prompt-wav prompt.wav \ | |
--prompt-text "I am a prompt." \ | |
--text "I am a sentence." \ | |
--res-wav-path result.wav | |
(2) Inference of a list of sentences: | |
python3 -m zipvoice.bin.infer_zipvoice \ | |
--model-name "zipvoice" \ | |
--test-list test.tsv \ | |
--res-dir results | |
`--model-name` can be `zipvoice` or `zipvoice_distill`, | |
which are the models before and after distillation, respectively. | |
Each line of `test.tsv` is in the format of | |
`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`. | |
""" | |
import argparse
import datetime as dt
import json
import os
from typing import Optional

import numpy as np
import safetensors.torch
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from lhotse.utils import fix_random_seed
from vocos import Vocos

from zipvoice.models.zipvoice import ZipVoice
from zipvoice.models.zipvoice_distill import ZipVoiceDistill
from zipvoice.tokenizer.tokenizer import (
    EmiliaTokenizer,
    EspeakTokenizer,
    LibriTTSTokenizer,
    SimpleTokenizer,
)
from zipvoice.utils.checkpoint import load_checkpoint
from zipvoice.utils.common import AttributeDict
from zipvoice.utils.feature import VocosFbank
HUGGINGFACE_REPO = "k2-fsa/ZipVoice"

PRETRAINED_MODEL = {
    "zipvoice": "zipvoice/model.pt",
    "zipvoice_distill": "zipvoice_distill/model.pt",
}

TOKEN_FILE = {
    "zipvoice": "zipvoice/tokens.txt",
    "zipvoice_distill": "zipvoice_distill/tokens.txt",
}

MODEL_CONFIG = {
    "zipvoice": "zipvoice/zipvoice_base.json",
    "zipvoice_distill": "zipvoice_distill/zipvoice_base.json",
}

# torch.set_num_threads(1)
# torch.set_num_interop_threads(1)
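# Sketch (assumption): this stripped-down script loads a local checkpoint below,
# but the repo/file constants above are intended for the auto-download case
# described in the module docstring. A hypothetical lookup would be:
#   ckpt_path = hf_hub_download(HUGGINGFACE_REPO, filename=PRETRAINED_MODEL["zipvoice"])
#   token_path = hf_hub_download(HUGGINGFACE_REPO, filename=TOKEN_FILE["zipvoice"])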
def get_vocoder(vocos_local_path: Optional[str] = None):
    """Load the Vocos vocoder from a local directory if given,
    otherwise download it from HuggingFace."""
    if vocos_local_path:
        vocoder = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
        state_dict = torch.load(
            f"{vocos_local_path}/pytorch_model.bin",
            weights_only=True,
            map_location="cpu",
        )
        vocoder.load_state_dict(state_dict)
    else:
        vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")
    return vocoder
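# Example usage (sketch): the local directory below is a hypothetical path that
# would contain Vocos' config.yaml and pytorch_model.bin; with no argument the
# charactr/vocos-mel-24khz checkpoint is fetched from HuggingFace.
#   vocoder = get_vocoder("/path/to/vocos-mel-24khz")
#   vocoder = get_vocoder()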
def generate_sentence(
    prompt_text: str,
    prompt_wav: str,
    text: str,
    model: torch.nn.Module,
    vocoder: torch.nn.Module,
    tokenizer: EmiliaTokenizer,
    feature_extractor: VocosFbank,
    device: torch.device,
    num_step: int = 32,
    guidance_scale: float = 1.0,
    speed: float = 1.0,
    t_shift: float = 0.5,
    target_rms: float = 0.1,
    feat_scale: float = 0.1,
    sampling_rate: int = 24000,
):
    """
    Generate the waveform of a text based on a given prompt
    waveform and its transcription.

    Args:
        prompt_text (str): Transcription of the prompt wav.
        prompt_wav (str): Path to the prompt wav file.
        text (str): Text to be synthesized into a waveform.
        model (torch.nn.Module): The model used for generation.
        vocoder (torch.nn.Module): The vocoder used to convert features to waveforms.
        tokenizer (EmiliaTokenizer): The tokenizer used to convert text to tokens.
        feature_extractor (VocosFbank): The feature extractor used to
            extract acoustic features.
        device (torch.device): The device on which computations are performed.
        num_step (int, optional): Number of steps for decoding. Defaults to 32.
        guidance_scale (float, optional): Scale for classifier-free guidance.
            Defaults to 1.0.
        speed (float, optional): Speed control. Defaults to 1.0.
        t_shift (float, optional): Time shift. Defaults to 0.5.
        target_rms (float, optional): Target RMS for waveform normalization.
            Defaults to 0.1.
        feat_scale (float, optional): Scale for features. Defaults to 0.1.
        sampling_rate (int, optional): Sampling rate for the waveform.
            Defaults to 24000.

    Returns:
        torch.Tensor: The generated waveform, moved to CPU.
    """
    # Convert text to tokens
    tokens = tokenizer.texts_to_token_ids([text])
    prompt_tokens = tokenizer.texts_to_token_ids([prompt_text])

    # Load and preprocess prompt wav
    prompt_wav, prompt_sampling_rate = torchaudio.load(prompt_wav)
    if prompt_sampling_rate != sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=prompt_sampling_rate, new_freq=sampling_rate
        )
        prompt_wav = resampler(prompt_wav)
    prompt_rms = torch.sqrt(torch.mean(torch.square(prompt_wav)))
    if prompt_rms < target_rms:
        prompt_wav = prompt_wav * target_rms / prompt_rms

    # Extract features from prompt wav
    prompt_features = feature_extractor.extract(
        prompt_wav, sampling_rate=sampling_rate
    ).to(device)
    prompt_features = prompt_features.unsqueeze(0) * feat_scale
    prompt_features_lens = torch.tensor([prompt_features.size(1)], device=device)

    # Start timing
    start_t = dt.datetime.now()

    # Generate features
    (
        pred_features,
        pred_features_lens,
        pred_prompt_features,
        pred_prompt_features_lens,
    ) = model.sample(
        tokens=tokens,
        prompt_tokens=prompt_tokens,
        prompt_features=prompt_features,
        prompt_features_lens=prompt_features_lens,
        speed=speed,
        t_shift=t_shift,
        duration="predict",
        num_step=num_step,
        guidance_scale=guidance_scale,
    )

    # Postprocess predicted features
    pred_features = pred_features.permute(0, 2, 1) / feat_scale  # (B, C, T)

    # Start vocoder processing
    start_vocoder_t = dt.datetime.now()
    wav = vocoder.decode(pred_features).squeeze(1).clamp(-1, 1)

    # Calculate processing times and real-time factors
    t = (dt.datetime.now() - start_t).total_seconds()
    t_no_vocoder = (start_vocoder_t - start_t).total_seconds()
    t_vocoder = (dt.datetime.now() - start_vocoder_t).total_seconds()
    wav_seconds = wav.shape[-1] / sampling_rate
    rtf = t / wav_seconds
    rtf_no_vocoder = t_no_vocoder / wav_seconds
    rtf_vocoder = t_vocoder / wav_seconds
    # metrics = {
    #     "t": t,
    #     "t_no_vocoder": t_no_vocoder,
    #     "t_vocoder": t_vocoder,
    #     "wav_seconds": wav_seconds,
    #     "rtf": rtf,
    #     "rtf_no_vocoder": rtf_no_vocoder,
    #     "rtf_vocoder": rtf_vocoder,
    # }

    # Adjust wav volume if necessary
    if prompt_rms < target_rms:
        wav = wav * prompt_rms / target_rms
    # torchaudio.save(save_path, wav.cpu(), sample_rate=sampling_rate)
    # return metrics
    return wav.cpu()
model_defaults = {
    "zipvoice": {
        "num_step": 32,
        "guidance_scale": 1.0,
    },
    "zipvoice_distill": {
        "num_step": 8,
        "guidance_scale": 3.0,
    },
}
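# Note: these per-model decoding defaults are not applied automatically below;
# the batch sketch at the end of this file looks them up by model name.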
device = torch.device("cuda", 0)

print("Loading model...")

model_config = "config.json"
with open(model_config, "r") as f:
    model_config = json.load(f)

token_file = "tokens.txt"
tokenizer = EspeakTokenizer(token_file=token_file, lang="vi")
tokenizer_config = {"vocab_size": tokenizer.vocab_size, "pad_id": tokenizer.pad_id}

model_ckpt = "iter-525000-avg-2.pt"
model = ZipVoice(
    **model_config["model"],
    **tokenizer_config,
)
load_checkpoint(filename=model_ckpt, model=model, strict=True)
model = model.to(device)
model.eval()

vocoder = get_vocoder(None)
vocoder = vocoder.to(device)
vocoder.eval()

if model_config["feature"]["type"] == "vocos":
    feature_extractor = VocosFbank()
else:
    raise NotImplementedError(
        f"Unsupported feature type: {model_config['feature']['type']}"
    )

sampling_rate = model_config["feature"]["sampling_rate"]
# Single-sentence inference (commented out; generate_sentence now returns the
# waveform instead of taking a save_path argument):
# wav = generate_sentence(
#     prompt_text=prompt_text,
#     prompt_wav=prompt_wav,
#     text=text,
#     model=model,
#     vocoder=vocoder,
#     tokenizer=tokenizer,
#     feature_extractor=feature_extractor,
#     device=device,
#     num_step=16,
#     guidance_scale=1.0,
#     speed=speed,
#     t_shift=0.5,
#     target_rms=0.1,
#     feat_scale=0.1,
#     sampling_rate=sampling_rate,
# )
# torchaudio.save(res_wav_path, wav, sample_rate=sampling_rate)
# print("Done")
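# Batch sketch (assumption): if a local `test.tsv` exists in the format described
# in the module docstring (`{wav_name}\t{prompt_transcription}\t{prompt_wav}\t{text}`),
# synthesize each line and write the results under a `results/` directory.
# The file name and output directory are placeholders, not part of the original
# script; the decoding settings are taken from model_defaults["zipvoice"].
if os.path.exists("test.tsv"):
    os.makedirs("results", exist_ok=True)
    with open("test.tsv", "r", encoding="utf-8") as f, torch.inference_mode():
        for line in f:
            if not line.strip():
                continue
            wav_name, prompt_transcript, prompt_path, target_text = (
                line.strip().split("\t")
            )
            wav = generate_sentence(
                prompt_text=prompt_transcript,
                prompt_wav=prompt_path,
                text=target_text,
                model=model,
                vocoder=vocoder,
                tokenizer=tokenizer,
                feature_extractor=feature_extractor,
                device=device,
                num_step=model_defaults["zipvoice"]["num_step"],
                guidance_scale=model_defaults["zipvoice"]["guidance_scale"],
                sampling_rate=sampling_rate,
            )
            torchaudio.save(
                f"results/{wav_name}.wav", wav, sample_rate=sampling_rate
            )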