from fairseq.dataclass.configs import FairseqConfig from fairseq import utils from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder from fairseq import checkpoint_utils, options, tasks, utils from fairseq.distributed import utils as distributed_utils import torch import json from tqdm import tqdm import random import soundfile as sf import numpy as np import ast import time import math from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.token_generation_constraints import pack_constraints, unpack_constraints from fairseq_cli.generate import get_symbols_to_strip_from_output from collections import namedtuple import sys from argparse import Namespace import argparse import sentencepiece as spm import re Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints") Translation = namedtuple("Translation", "src_str hypos pos_scores alignments") def make_batches(lines, cfg, task, max_positions, encode_fn): def encode_fn_target(x): return encode_fn(x) if cfg.generation.constraints: # Strip (tab-delimited) contraints, if present, from input lines, # store them in batch_constraints batch_constraints = [list() for _ in lines] for i, line in enumerate(lines): if "\t" in line: lines[i], *batch_constraints[i] = line.split("\t") # Convert each List[str] to List[Tensor] for i, constraint_list in enumerate(batch_constraints): batch_constraints[i] = [ task.target_dictionary.encode_line( encode_fn_target(constraint), append_eos=False, add_if_not_exist=False, ) for constraint in constraint_list ] if cfg.generation.constraints: constraints_tensor = pack_constraints(batch_constraints) else: constraints_tensor = None tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn) itr = task.get_batch_iterator( dataset=task.build_dataset_for_inference( tokens, lengths, constraints=constraints_tensor ), max_tokens=cfg.dataset.max_tokens, max_sentences=cfg.dataset.batch_size, max_positions=max_positions, ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, ).next_epoch_itr(shuffle=False) for batch in itr: ids = batch["id"] src_tokens = batch["net_input"]["src_tokens"] src_lengths = batch["net_input"]["src_lengths"] constraints = batch.get("constraints", None) yield Batch( ids=ids, src_tokens=src_tokens, src_lengths=src_lengths, constraints=constraints, ) def tokenize(inputs, sp): text = re.sub(r'[^\w\s]', '', inputs.lower()) inputs = ' '.join(sp.EncodeAsPieces(text)) # print(inputs) return inputs def get_t2u_config(model, beam=5): sys.argv = [ "fairseq-interactive", "libri_t2u", "--path", model, "--gen-subset", "valid", "--max-len-b", "1024", "--max-source-positions", "500", "--max-target-positions", "1024", "--beam", str(beam), "--results-path", "decode" ] parser = options.get_interactive_generation_parser() args = options.parse_args_and_arch(parser) # distributed_utils.call_main(convert_namespace_to_omegaconf(args), load_text2units_model) return convert_namespace_to_omegaconf(args) def load_text2units_model(cfg: FairseqConfig, device): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) if cfg.interactive.buffer_size < 1: cfg.interactive.buffer_size = 1 if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: cfg.dataset.batch_size = 1 assert ( not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert ( not cfg.dataset.batch_size or cfg.dataset.batch_size <= cfg.interactive.buffer_size ), "--batch-size cannot be larger than --buffer-size" # Fix seed for stochastic decoding if cfg.common.seed is not None and not cfg.generation.no_seed_provided: np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Setup task, e.g., translation task = tasks.setup_task(cfg.task) # Load ensemble overrides = ast.literal_eval(cfg.common_eval.model_overrides) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Set dictionaries src_dict = task.source_dictionary tgt_dict = task.target_dictionary # Optimize ensemble for generation for model in models: if model is None: continue if cfg.common.fp16: model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() model.prepare_for_inference_(cfg) # Initialize generator generator = task.build_generator(models, cfg.generation) # Handle tokenization and BPE tokenizer = task.build_tokenizer(cfg.tokenizer) bpe = task.build_bpe(cfg.bpe) return { "models": models, "generator": generator, "tokenizer": tokenizer, "bpe": bpe, "task": task, "src_dict": src_dict, "tgt_dict": tgt_dict, "use_cuda": use_cuda } def gen_units(model, cfg, inputs): inputs = [inputs] models = model['models'] generator = model['generator'] tokenizer = model['tokenizer'] bpe = model['bpe'] task = model['task'] src_dict = model['src_dict'] tgt_dict = model['tgt_dict'] use_cuda = model['use_cuda'] def encode_fn(x): if tokenizer is not None: x = tokenizer.encode(x) if bpe is not None: x = bpe.encode(x) return x def decode_fn(x): if bpe is not None: x = bpe.decode(x) if tokenizer is not None: x = tokenizer.decode(x) return x align_dict = utils.load_align_dict(cfg.generation.replace_unk) max_positions = utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models] ) start_id = 0 results = [] for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): print("[INFO_DEBUG]", batch) bsz = batch.src_tokens.size(0) src_tokens = batch.src_tokens src_lengths = batch.src_lengths constraints = batch.constraints if use_cuda: src_tokens = src_tokens.cuda() src_lengths = src_lengths.cuda() if constraints is not None: constraints = constraints.cuda() sample = { "net_input": { "src_tokens": src_tokens, "src_lengths": src_lengths, }, } translate_start_time = time.time() translations = task.inference_step( generator, models, sample, constraints=constraints ) translate_time = time.time() - translate_start_time list_constraints = [[] for _ in range(bsz)] if cfg.generation.constraints: list_constraints = [unpack_constraints(c) for c in constraints] for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) constraints = list_constraints[i] results.append( ( start_id + id, src_tokens_i, hypos, { "constraints": constraints, "time": translate_time / len(translations), }, ) ) # print(results) units = [] for id_, _, hypos, info in sorted(results, key=lambda x: x[0]): print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) # Process top predictions for hypo in hypos[: min(len(hypos), cfg.generation.nbest)]: hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str="", alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=cfg.common_eval.post_process, extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator), ) units.append(list(map(int, hypo_str.split(' ')))) return units def get_vocoder_config(vocoder, config): args = argparse.Namespace( vocoder=vocoder, vocoder_cfg=config, dur_prediction=True, speaker_id=1, cpu=False ) return args def load_units_vocoder(args, device): with open(args.vocoder_cfg) as f: vocoder_cfg = json.load(f) vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg).to(device) multispkr = vocoder.model.multispkr if multispkr: num_speakers = vocoder_cfg.get( "num_speakers", 200 ) # following the default in codehifigan to set to 200 assert ( args.speaker_id < num_speakers ), f"invalid --speaker-id ({args.speaker_id}) with total #speakers = {num_speakers}" return vocoder, num_speakers if multispkr else 1, 'cuda' in device def gen_wav(vocoder, args, data, device): vocoder, num_speakers, use_cuda = vocoder res = [] for i, d in enumerate(data): # tqdm is removed for cleaner streaming x = { "code": torch.LongTensor(d).view(1, -1).to(device), } suffix = "" multispkr = vocoder.model.multispkr if multispkr: spk = ( random.randint(0, num_speakers - 1) if args.speaker_id == -1 else args.speaker_id ) suffix = f"_spk{spk}" x["spkr"] = torch.LongTensor([spk]).view(1, 1) x = utils.move_to_cuda(x) if use_cuda else x wav = vocoder(x, args.dur_prediction).detach().cpu().numpy() res.append(wav) return res[0]