# EchoX/text_to_speech.py
import argparse
import ast
import json
import random
import re
import sys
import time
from argparse import Namespace
from collections import namedtuple

import numpy as np
import sentencepiece as spm
import soundfile as sf
import torch

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
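
# Lightweight containers mirroring fairseq_cli.interactive: Batch is yielded by
# make_batches, Translation is kept for parity with the upstream script.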
Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
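
# Build mini-batches from raw input lines for interactive generation, following
# fairseq_cli.interactive: optional tab-delimited constraints are split off each
# line, encoded with the target dictionary, and packed into a constraints tensor.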
def make_batches(lines, cfg, task, max_positions, encode_fn):
def encode_fn_target(x):
return encode_fn(x)
if cfg.generation.constraints:
        # Strip (tab-delimited) constraints, if present, from the input lines
        # and store them in batch_constraints.
batch_constraints = [list() for _ in lines]
for i, line in enumerate(lines):
if "\t" in line:
lines[i], *batch_constraints[i] = line.split("\t")
# Convert each List[str] to List[Tensor]
for i, constraint_list in enumerate(batch_constraints):
batch_constraints[i] = [
task.target_dictionary.encode_line(
encode_fn_target(constraint),
append_eos=False,
add_if_not_exist=False,
)
for constraint in constraint_list
]
if cfg.generation.constraints:
constraints_tensor = pack_constraints(batch_constraints)
else:
constraints_tensor = None
tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(
tokens, lengths, constraints=constraints_tensor
),
max_tokens=cfg.dataset.max_tokens,
max_sentences=cfg.dataset.batch_size,
max_positions=max_positions,
ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
).next_epoch_itr(shuffle=False)
for batch in itr:
ids = batch["id"]
src_tokens = batch["net_input"]["src_tokens"]
src_lengths = batch["net_input"]["src_lengths"]
constraints = batch.get("constraints", None)
yield Batch(
ids=ids,
src_tokens=src_tokens,
src_lengths=src_lengths,
constraints=constraints,
)
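
# Lowercase the input, strip punctuation, and re-segment it into SentencePiece
# pieces so that it matches the text-to-unit model's training-time tokenization.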
def tokenize(inputs, sp):
text = re.sub(r'[^\w\s]', '', inputs.lower())
inputs = ' '.join(sp.EncodeAsPieces(text))
return inputs
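
# Build a fairseq generation config for the text-to-unit (T2U) model by faking a
# fairseq-interactive command line and converting the parsed args to omegaconf.
# "libri_t2u" is the data directory that holds the model's dictionaries.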
def get_t2u_config(model, beam=5):
sys.argv = [
"fairseq-interactive",
"libri_t2u",
"--path", model,
"--gen-subset", "valid",
"--max-len-b", "1024",
"--max-source-positions", "500",
"--max-target-positions", "1024",
"--beam", str(beam),
"--results-path", "decode"
]
parser = options.get_interactive_generation_parser()
args = options.parse_args_and_arch(parser)
return convert_namespace_to_omegaconf(args)
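
# Load the T2U model ensemble together with its task, dictionaries, tokenizer/BPE,
# and generator, mirroring fairseq_cli.interactive. CUDA placement follows
# cfg.common.cpu and torch.cuda.is_available(); the `device` argument is accepted
# for symmetry with load_units_vocoder but is not used here.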
def load_text2units_model(cfg: FairseqConfig, device):
if isinstance(cfg, Namespace):
cfg = convert_namespace_to_omegaconf(cfg)
utils.import_user_module(cfg.common)
if cfg.interactive.buffer_size < 1:
cfg.interactive.buffer_size = 1
if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None:
cfg.dataset.batch_size = 1
assert (
not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
not cfg.dataset.batch_size
or cfg.dataset.batch_size <= cfg.interactive.buffer_size
), "--batch-size cannot be larger than --buffer-size"
# Fix seed for stochastic decoding
if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
np.random.seed(cfg.common.seed)
utils.set_torch_seed(cfg.common.seed)
use_cuda = torch.cuda.is_available() and not cfg.common.cpu
# Setup task, e.g., translation
task = tasks.setup_task(cfg.task)
# Load ensemble
overrides = ast.literal_eval(cfg.common_eval.model_overrides)
models, _model_args = checkpoint_utils.load_model_ensemble(
utils.split_paths(cfg.common_eval.path),
arg_overrides=overrides,
task=task,
suffix=cfg.checkpoint.checkpoint_suffix,
strict=(cfg.checkpoint.checkpoint_shard_count == 1),
num_shards=cfg.checkpoint.checkpoint_shard_count,
)
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
for model in models:
if model is None:
continue
if cfg.common.fp16:
model.half()
if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(cfg)
# Initialize generator
generator = task.build_generator(models, cfg.generation)
# Handle tokenization and BPE
tokenizer = task.build_tokenizer(cfg.tokenizer)
bpe = task.build_bpe(cfg.bpe)
return {
"models": models,
"generator": generator,
"tokenizer": tokenizer,
"bpe": bpe,
"task": task,
"src_dict": src_dict,
"tgt_dict": tgt_dict,
"use_cuda": use_cuda
}
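
# Translate a single input string into discrete speech units with the loaded T2U
# model, returning the n-best hypotheses as lists of integer unit ids.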
def gen_units(model, cfg, inputs):
inputs = [inputs]
models = model['models']
generator = model['generator']
tokenizer = model['tokenizer']
bpe = model['bpe']
task = model['task']
src_dict = model['src_dict']
tgt_dict = model['tgt_dict']
use_cuda = model['use_cuda']
def encode_fn(x):
if tokenizer is not None:
x = tokenizer.encode(x)
if bpe is not None:
x = bpe.encode(x)
return x
def decode_fn(x):
if bpe is not None:
x = bpe.decode(x)
if tokenizer is not None:
x = tokenizer.decode(x)
return x
align_dict = utils.load_align_dict(cfg.generation.replace_unk)
max_positions = utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
)
start_id = 0
results = []
for batch in make_batches(inputs, cfg, task, max_positions, encode_fn):
print("[INFO_DEBUG]", batch)
bsz = batch.src_tokens.size(0)
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
constraints = batch.constraints
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
if constraints is not None:
constraints = constraints.cuda()
sample = {
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
},
}
translate_start_time = time.time()
translations = task.inference_step(
generator, models, sample, constraints=constraints
)
translate_time = time.time() - translate_start_time
list_constraints = [[] for _ in range(bsz)]
if cfg.generation.constraints:
list_constraints = [unpack_constraints(c) for c in constraints]
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
constraints = list_constraints[i]
results.append(
(
start_id + id,
src_tokens_i,
hypos,
{
"constraints": constraints,
"time": translate_time / len(translations),
},
)
)
units = []
for id_, _, hypos, info in sorted(results, key=lambda x: x[0]):
print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
# Process top predictions
for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo["tokens"].int().cpu(),
src_str="",
alignment=hypo["alignment"],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=cfg.common_eval.post_process,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)
units.append(list(map(int, hypo_str.split(' '))))
return units
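
# Bundle the CodeHiFiGAN checkpoint path, its JSON config, and decoding options
# (duration prediction, speaker id) into an argparse.Namespace.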
def get_vocoder_config(vocoder, config):
args = argparse.Namespace(
vocoder=vocoder,
vocoder_cfg=config,
dur_prediction=True,
speaker_id=1,
cpu=False
)
return args
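
# Load the CodeHiFiGAN unit vocoder and, for multi-speaker checkpoints, validate
# the requested speaker id against the number of speakers declared in the config.
# Returns a (vocoder, num_speakers, use_cuda) tuple consumed by gen_wav.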
def load_units_vocoder(args, device):
with open(args.vocoder_cfg) as f:
vocoder_cfg = json.load(f)
vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg).to(device)
multispkr = vocoder.model.multispkr
if multispkr:
num_speakers = vocoder_cfg.get(
"num_speakers", 200
) # following the default in codehifigan to set to 200
assert (
args.speaker_id < num_speakers
), f"invalid --speaker-id ({args.speaker_id}) with total #speakers = {num_speakers}"
    return vocoder, num_speakers if multispkr else 1, 'cuda' in str(device)
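
# Synthesize waveforms from unit sequences with the vocoder. Multi-speaker
# vocoders get a speaker index attached (random when args.speaker_id == -1).
# Only the waveform for the first unit sequence is returned.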
def gen_wav(vocoder, args, data, device):
vocoder, num_speakers, use_cuda = vocoder
res = []
    for d in data:  # tqdm is removed for cleaner streaming
x = {
"code": torch.LongTensor(d).view(1, -1).to(device),
}
suffix = ""
multispkr = vocoder.model.multispkr
if multispkr:
spk = (
random.randint(0, num_speakers - 1)
if args.speaker_id == -1
else args.speaker_id
)
suffix = f"_spk{spk}"
x["spkr"] = torch.LongTensor([spk]).view(1, 1)
x = utils.move_to_cuda(x) if use_cuda else x
wav = vocoder(x, args.dur_prediction).detach().cpu().numpy()
res.append(wav)
return res[0]
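

# A minimal end-to-end sketch of how the helpers above fit together, assuming a
# trained text-to-unit checkpoint, a SentencePiece model, and a CodeHiFiGAN
# vocoder are available. All file paths below are placeholders, not files that
# ship with this module; adjust them to your own setup.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Text -> discrete units.
    sp = spm.SentencePieceProcessor(model_file="spm.model")   # placeholder path
    t2u_cfg = get_t2u_config("t2u_checkpoint.pt", beam=5)     # placeholder path
    t2u_model = load_text2units_model(t2u_cfg, device)
    units = gen_units(t2u_model, t2u_cfg, tokenize("Hello, this is a test.", sp))

    # Discrete units -> waveform.
    voc_args = get_vocoder_config("vocoder.pt", "vocoder_config.json")  # placeholder paths
    vocoder = load_units_vocoder(voc_args, device)
    wav = gen_wav(vocoder, voc_args, units, device)

    # CodeHiFiGAN configs usually carry the output sampling rate; fall back to 16 kHz.
    with open(voc_args.vocoder_cfg) as f:
        sample_rate = json.load(f).get("sampling_rate", 16000)
    sf.write("output.wav", wav, sample_rate)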