# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Xiao Chen)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
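"""Example end-to-end TTS pipeline.

The script tokenizes a reference wav into semantic tokens, converts the input
text to phonemes and then to speech tokens, and finally detokenizes those
tokens into a waveform with a chunk-based detokenizer/vocoder.
"""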
import argparse
import logging
import sys
from pathlib import Path

import librosa
import soundfile as sf

# Make the bundled submodules (G2P, speech tokenizer, text-to-token generator,
# and detokenizer) importable regardless of the current working directory.
sub_modules = ["", "thirdparty/G2P", "semantic_tokenizer/f40ms", "text2token", "semantic_detokenizer"]
for sub in sub_modules:
    sys.path.append(str((Path(__file__).parent / sub).absolute()))

from semantic_tokenizer.f40ms.simple_tokenizer_infer import SpeechTokenizer, TOKENIZER_CFG_NAME
from text2token.simple_infer import Text2TokenGenerator
from semantic_detokenizer.chunk_infer import SpeechDetokenizer
class TTSPipeline:
    def __init__(
        self,
        detok_vocoder: str,
        tokenizer_cfg_name: str = TOKENIZER_CFG_NAME,
        tokenizer_cfg_path: str = str(
            (Path(__file__).parent / "semantic_tokenizer/f40ms/config").absolute()
        ),
        tokenizer_ckpt: str = str(
            (
                Path(__file__).parent / "semantic_tokenizer/f40ms/ckpt/model.pt"
            ).absolute()
        ),
        max_seg_len: int = 0,
        detok_model_cfg: str = str(
            (Path(__file__).parent / "semantic_detokenizer/ckpt/model.yaml").absolute()
        ),
        detok_ckpt: str = str(
            (Path(__file__).parent / "semantic_detokenizer/ckpt/model.pt").absolute()
        ),
        detok_vocab: str = str(
            (
                Path(__file__).parent / "semantic_detokenizer/ckpt/vocab_4096.txt"
            ).absolute()
        ),
    ):
        # Speech tokenizer: maps a waveform to a sequence of semantic tokens.
        self.tokenizer_cfg_name = tokenizer_cfg_name
        self.tokenizer = SpeechTokenizer(
            ckpt_path=tokenizer_ckpt,
            cfg_path=tokenizer_cfg_path,
            cfg_name=self.tokenizer_cfg_name,
        )
        # Text-to-token generator: input text -> phonemes -> speech tokens.
        self.t2u_max_seg_len = max_seg_len
        self.t2u = Text2TokenGenerator()
        # The detokenizer runs on the first GPU.
        self.device = "cuda:0"
        # Detokenizer: speech tokens -> waveform, conditioned on the reference audio.
        self.detoker = SpeechDetokenizer(
            vocoder_path=detok_vocoder,
            model_cfg=detok_model_cfg,
            ckpt_file=detok_ckpt,
            vocab_file=detok_vocab,
            device=self.device,
        )
        # Chunked-generation parameters.
        self.token_chunk_len = 75
        self.chunk_cond_proportion = 0.3
        self.chunk_look_ahead = 10
        self.max_ref_duration = 4.5
        self.ref_audio_cut_from_head = False
    def synthesize(self, ref_wav, input_text):
        # Tokenize the reference audio (resampled to 16 kHz) into semantic tokens.
        ref_wavs_list = []
        raw_wav, sr = librosa.load(ref_wav, sr=16000)
        ref_wavs_list.append(raw_wav)
        token_list, token_info_list = self.tokenizer.extract(ref_wavs_list)
        ref_token_list = token_info_list[0]["reduced_unit_sequence"]
        logging.info("tokens for ref wav: %s are [%s]" % (ref_wav, ref_token_list))

        # Convert the input text to phonemes, then to speech tokens.
        phones = self.t2u.text2phone(input_text.strip())
        logging.info("phonemes of input text: %s are [%s]" % (input_text, phones))
        speech_tokens_info = self.t2u.generate_for_long_input_text(
            [phones], max_segment_len=self.t2u_max_seg_len
        )

        # Detokenize the speech tokens chunk by chunk, conditioned on the reference audio.
        generated_wave, target_sample_rate = self.detoker.chunk_generate(
            ref_wav,
            ref_token_list.split(),
            speech_tokens_info[0][0],
            self.token_chunk_len,
            self.chunk_cond_proportion,
            self.chunk_look_ahead,
            self.max_ref_duration,
            self.ref_audio_cut_from_head,
        )
        if generated_wave is None:
            logging.info("generation FAILED")
            return None, None
        return generated_wave, target_sample_rate
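
# A minimal sketch of programmatic use (the paths below are placeholders for
# your local checkpoints; only the vocoder has no bundled default):
#
#   tts = TTSPipeline(detok_vocoder="/path/to/vocoder")
#   wav, sr = tts.synthesize("/path/to/reference.wav", "Hello world.")
#   if wav is not None:
#       sf.write("out.wav", wav, sr)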
def main(args):
    # Initialize the pipeline; components not specified here fall back to the
    # default checkpoint locations bundled with this repository.
    tts = TTSPipeline(
        detok_vocoder=args.detok_vocoder,
        max_seg_len=args.max_seg_len,
    )
    generated_wave, target_sample_rate = tts.synthesize(args.ref_wav, args.input_text)
    if generated_wave is None:
        logging.error("synthesis failed, no output written")
        return
    sf.write(args.output_wav, generated_wave, target_sample_rate)
    logging.info(f"write output to: {args.output_wav}")
    logging.info("Finished")
    return
if __name__ == "__main__":
    # Emit the pipeline's info-level progress messages to the console.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer-ckpt",
        required=False,
        help="path to the speech tokenizer checkpoint",
    )
    parser.add_argument(
        "--tokenizer-cfg-path",
        required=False,
        default="semantic_tokenizer/f40ms/config",
        help="path to the speech tokenizer config directory",
    )
    parser.add_argument(
        "--detok-ckpt",
        required=False,
        help="path to the detokenizer checkpoint",
    )
    parser.add_argument(
        "--detok-model-cfg",
        required=False,
        help="path to the detokenizer model config",
    )
    parser.add_argument(
        "--detok-vocab",
        required=False,
        help="path to the detokenizer vocabulary",
    )
    parser.add_argument(
        "--detok-vocoder",
        required=True,
        help="path to the vocoder",
    )
    parser.add_argument(
        "--ref-wav",
        required=True,
        help="path to the reference wav",
    )
    parser.add_argument(
        "--max-seg-len",
        required=False,
        default=0,
        type=int,
        help="max segment length for text-to-token generation",
    )
    parser.add_argument(
        "--output-wav",
        required=True,
        help="path to the output synthesized wav",
    )
    parser.add_argument(
        "--input-text",
        required=True,
        help="input text to synthesize",
    )
    args = parser.parse_args()
    main(args)
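
# Example invocation (the vocoder and reference-wav paths are placeholders;
# adapt them to where the models and audio live on your machine):
#
#   python tts_example.py \
#       --detok-vocoder /path/to/vocoder \
#       --ref-wav /path/to/reference.wav \
#       --input-text "Text to synthesize." \
#       --output-wav out.wav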