import hashlib
import math
import os
import shutil
import subprocess
import tempfile
import threading
import uuid

import numpy as np
import regex as re
import soundfile as sf
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from pathlib import Path
from pprint import pprint

from lib import *
from lib.classes.tts_engines.common.utils import unload_tts, append_sentence2vtt
from lib.classes.tts_engines.common.audio_filters import detect_gender, trim_audio, normalize_audio, is_audio_data_valid

#import logging
#logging.basicConfig(level=logging.DEBUG)

# Serializes model load/unload across threads (model loading is not re-entrant).
lock = threading.Lock()


class Coqui:
    """TTS engine wrapper template.

    Loads a TTS model (``NEW_TTS`` is a placeholder for a concrete engine),
    keeps it in the shared ``loaded_tts`` registry keyed by
    ``"<engine>-<fine_tuned>"``, and converts sentences to audio files while
    maintaining a VTT subtitle track.

    NOTE(review): this file reads like a template — ``NEW_TTS.WITH_CUDA``,
    ``NEW_TTS.CONVERT()`` etc. are stand-ins to be replaced per engine.
    """

    def __init__(self, session):
        """Initialize per-session state and trigger the model build.

        session: dict-like with keys 'tts_engine', 'fine_tuned', 'device',
        'process_dir', 'final_name', ... (supplied by the caller; schema is
        defined elsewhere in the project).
        """
        try:
            self.session = session
            self.cache_dir = tts_dir
            self.speakers_path = None
            # Registry keys: one for the TTS model, one for the voice-conversion model.
            self.tts_key = f"{self.session['tts_engine']}-{self.session['fine_tuned']}"
            self.tts_vc_key = default_vc_model.rsplit('/', 1)[-1]
            # bf16 is only usable on CUDA devices that report support for it.
            self.is_bf16 = self.session['device'] == 'cuda' and torch.cuda.is_bf16_supported()
            self.npz_path = None
            self.npz_data = None
            self.sentences_total_time = 0.0
            self.sentence_idx = 1
            self.params = {TTS_ENGINES['NEW_TTS']: {}}
            self.params[self.session['tts_engine']]['samplerate'] = models[self.session['tts_engine']][self.session['fine_tuned']]['samplerate']
            self.vtt_path = os.path.join(self.session['process_dir'], os.path.splitext(self.session['final_name'])[0] + '.vtt')
            self.resampler_cache = {}
            self.audio_segments = []
            self._build()
        except Exception as e:
            error = f'__init__() error: {e}'
            print(error)
            return None

    def _build(self):
        """Ensure the engine for ``self.tts_key`` is loaded.

        Returns the engine object from ``loaded_tts`` on success, False on
        failure (custom models are not implemented yet for this engine).
        """
        try:
            tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
            if not tts:
                if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
                    if self.session['custom_model'] is not None:
                        msg = f"{self.session['tts_engine']} custom model not implemented yet!"
                        print(msg)
                        return False
                    else:
                        model_path = models[self.session['tts_engine']][self.session['fine_tuned']]['repo']
                        tts = self._load_api(self.tts_key, model_path, self.session['device'])
            # Re-read the registry so we return whatever actually got stored.
            return (loaded_tts.get(self.tts_key) or {}).get('engine', False)
        except Exception as e:
            error = f'build() error: {e}'
            print(error)
            return False

    def _load_api(self, key, model_path, device):
        """Load a model via the engine's high-level API and register it.

        Returns the engine instance, or False on failure.
        """
        global lock
        try:
            # Already loaded by an earlier call — reuse it.
            if key in loaded_tts.keys():
                return loaded_tts[key]['engine']
            # Free memory used by any previously loaded engines first.
            unload_tts(device, [self.tts_key, self.tts_vc_key])
            with lock:
                # BUG FIX: the original line was garbled
                # (``tts = NEW_TTS(model_path) if tts if device == 'cuda':``);
                # reconstructed as: construct, then branch on device.
                tts = NEW_TTS(model_path)
                if tts:
                    if device == 'cuda':
                        NEW_TTS.WITH_CUDA
                    else:
                        NEW_TTS.WITHOUT_CUDA
                    loaded_tts[key] = {"engine": tts, "config": None}
                    msg = f'{model_path} Loaded!'
                    print(msg)
                    return tts
                else:
                    error = 'TTS engine could not be created!'
                    print(error)
        except Exception as e:
            error = f'_load_api() error: {e}'
            print(error)
        return False

    def _load_checkpoint(self, **kwargs):
        """Load a model from a local checkpoint directory and register it.

        kwargs: key, tts_engine, device, checkpoint_dir (and presumably
        config — see note below). Returns the engine instance, or False.
        """
        global lock
        try:
            key = kwargs.get('key')
            if key in loaded_tts.keys():
                return loaded_tts[key]['engine']
            tts_engine = kwargs.get('tts_engine')
            device = kwargs.get('device')
            unload_tts(device, [self.tts_key])
            with lock:
                checkpoint_dir = kwargs.get('checkpoint_dir')
                # NOTE(review): the original referenced an undefined ``config``
                # and never assigned the LOAD_CHECKPOINT result, leaving ``tts``
                # undefined below. Pulling config from kwargs and capturing the
                # return value — confirm against the concrete engine's API.
                config = kwargs.get('config')
                tts = NEW_TTS.LOAD_CHECKPOINT(
                    config,
                    checkpoint_dir=checkpoint_dir,
                    eval=True
                )
                if tts:
                    if device == 'cuda':
                        NEW_TTS.WITH_CUDA
                    else:
                        NEW_TTS.WITHOUT_CUDA
                    loaded_tts[key] = {"engine": tts, "config": config}
                    msg = f'{tts_engine} Loaded!'
                    print(msg)
                    return tts
                else:
                    error = 'TTS engine could not be created!'
                    print(error)
        except Exception as e:
            error = f'_load_checkpoint() error: {e}'
            # BUG FIX: the original swallowed the error silently.
            print(error)
        return False

    def _tensor_type(self, audio_data):
        """Coerce engine output (Tensor, ndarray, or list) to a float Tensor."""
        if isinstance(audio_data, torch.Tensor):
            return audio_data
        elif isinstance(audio_data, np.ndarray):
            return torch.from_numpy(audio_data).float()
        elif isinstance(audio_data, list):
            return torch.tensor(audio_data, dtype=torch.float32)
        else:
            raise TypeError(f"Unsupported type for audio_data: {type(audio_data)}")

    def _get_resampler(self, orig_sr, target_sr):
        """Return a cached torchaudio Resample transform for (orig_sr, target_sr)."""
        key = (orig_sr, target_sr)
        if key not in self.resampler_cache:
            self.resampler_cache[key] = torchaudio.transforms.Resample(
                orig_freq=orig_sr, new_freq=target_sr
            )
        return self.resampler_cache[key]

    def _resample_wav(self, wav_path, expected_sr):
        """Return a path to a mono wav at ``expected_sr``.

        Returns ``wav_path`` unchanged when it is already mono at the expected
        rate; otherwise downmixes/resamples into a NamedTemporaryFile
        (delete=False — the caller owns the temp file's lifetime).
        """
        waveform, orig_sr = torchaudio.load(wav_path)
        if orig_sr == expected_sr and waveform.size(0) == 1:
            return wav_path
        if waveform.size(0) > 1:
            # Downmix multi-channel to mono by averaging channels.
            waveform = waveform.mean(dim=0, keepdim=True)
        if orig_sr != expected_sr:
            resampler = self._get_resampler(orig_sr, expected_sr)
            waveform = resampler(waveform)
        wav_tensor = waveform.squeeze(0)
        wav_numpy = wav_tensor.cpu().numpy()
        tmp_fh = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp_path = tmp_fh.name
        tmp_fh.close()
        sf.write(tmp_path, wav_numpy, expected_sr, subtype="PCM_16")
        return tmp_path

    def convert(self, sentence_number, sentence):
        """Synthesize one sentence (or SML break/pause marker) to audio.

        Accumulates tensors in ``self.audio_segments``; on a real sentence,
        concatenates pending segments, appends a VTT cue, and writes
        ``<sentence_number>.<fmt>`` into the session's sentences directory.
        Returns True on success, False otherwise.
        """
        global xtts_builtin_speakers_list
        try:
            speaker = None
            audio_data = False
            trim_audio_buffer = 0.004
            settings = self.params[self.session['tts_engine']]
            final_sentence_file = os.path.join(self.session['chapters_dir_sentences'], f'{sentence_number}.{default_audio_proc_format}')
            sentence = sentence.strip()
            # Voice priority: explicit session voice > custom model ref.wav > model default.
            settings['voice_path'] = (
                self.session['voice'] if self.session['voice'] is not None
                else os.path.join(self.session['custom_model_dir'], self.session['tts_engine'], self.session['custom_model'], 'ref.wav') if self.session['custom_model'] is not None
                else models[self.session['tts_engine']][self.session['fine_tuned']]['voice']
            )
            if settings['voice_path'] is not None:
                speaker = re.sub(r'\.wav$', '', os.path.basename(settings['voice_path']))
            tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
            if tts:
                # BUG FIX: guard against an empty sentence after strip();
                # the original indexed sentence[-1] unconditionally.
                if sentence and sentence[-1].isalnum():
                    # Trailing em dash keeps engines from clipping the last word.
                    sentence = f'{sentence} —'
                if sentence == TTS_SML['break']:
                    # Insert 0.3–0.6 s of silence (quantized to 10 ms steps).
                    break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100)))
                    self.audio_segments.append(break_tensor.clone())
                    return True
                elif sentence == TTS_SML['pause']:
                    # Insert 1.0–1.8 s of silence (quantized to 10 ms steps).
                    pause_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(1.0, 1.8) * 100) / 100)))
                    self.audio_segments.append(pause_tensor.clone())
                    return True
                else:
                    if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
                        audio_sentence = NEW_TTS.CONVERT()
                        # audio_sentence must be torch.Tensor or (list, tuple) or np.ndarray
                        if is_audio_data_valid(audio_sentence):
                            sourceTensor = self._tensor_type(audio_sentence)
                            audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
                            if sentence[-1].isalnum() or sentence[-1] == '—':
                                # Trim trailing silence on sentences that end mid-flow.
                                audio_tensor = trim_audio(audio_tensor.squeeze(), settings['samplerate'], 0.003, trim_audio_buffer).unsqueeze(0)
                            self.audio_segments.append(audio_tensor)
                            if not re.search(r'\w$', sentence, flags=re.UNICODE):
                                # Sentence ends with punctuation: add a short natural gap.
                                break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100)))
                                self.audio_segments.append(break_tensor.clone())
                    if self.audio_segments:
                        audio_tensor = torch.cat(self.audio_segments, dim=-1)
                        start_time = self.sentences_total_time
                        duration = audio_tensor.shape[-1] / settings['samplerate']
                        end_time = start_time + duration
                        self.sentences_total_time = end_time
                        sentence_obj = {
                            "start": start_time,
                            "end": end_time,
                            "text": sentence,
                            "resume_check": self.sentence_idx
                        }
                        self.sentence_idx = append_sentence2vtt(sentence_obj, self.vtt_path)
                        if self.sentence_idx:
                            torchaudio.save(final_sentence_file, audio_tensor, settings['samplerate'], format=default_audio_proc_format)
                            del audio_tensor
                            self.audio_segments = []
                            if os.path.exists(final_sentence_file):
                                return True
                            error = f"Cannot create {final_sentence_file}"
                            print(error)
            else:
                error = f"convert() error: {self.session['tts_engine']} is None"
                print(error)
        except Exception as e:
            # BUG FIX: typo in the original message ('Coquit.convert()').
            error = f'Coqui.convert(): {e}'
            raise ValueError(e)
        return False