import hashlib
import math
import os
import shutil
import subprocess
import tempfile
import threading
import uuid
import numpy as np
import regex as re
import soundfile as sf
import torch
import torchaudio
from huggingface_hub import hf_hub_download
from pathlib import Path
from pprint import pprint
from lib import *
from lib.classes.tts_engines.common.utils import unload_tts, append_sentence2vtt
from lib.classes.tts_engines.common.audio_filters import detect_gender, trim_audio, normalize_audio, is_audio_data_valid
#import logging
#logging.basicConfig(level=logging.DEBUG)
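# Guards concurrent engine load/unload across worker threads.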
lock = threading.Lock()
class Coqui:
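    """Wrapper that integrates a new TTS engine into the application's engine framework.

    NEW_TTS, NEW_TTS.WITH_CUDA, NEW_TTS.WITHOUT_CUDA, NEW_TTS.LOAD_CHECKPOINT and
    NEW_TTS.CONVERT are placeholders to be replaced by the real engine's constructor,
    device-placement, checkpoint-loading and inference calls.
    """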
def __init__(self, session):
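        """Cache the session, compute the sample rate and VTT output path, then load the engine via _build()."""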
try:
self.session = session
self.cache_dir = tts_dir
self.speakers_path = None
self.tts_key = f"{self.session['tts_engine']}-{self.session['fine_tuned']}"
self.tts_vc_key = default_vc_model.rsplit('/', 1)[-1]
            self.is_bf16 = self.session['device'] == 'cuda' and torch.cuda.is_bf16_supported()
self.npz_path = None
self.npz_data = None
self.sentences_total_time = 0.0
self.sentence_idx = 1
self.params = {TTS_ENGINES['NEW_TTS']: {}}
self.params[self.session['tts_engine']]['samplerate'] = models[self.session['tts_engine']][self.session['fine_tuned']]['samplerate']
self.vtt_path = os.path.join(self.session['process_dir'], os.path.splitext(self.session['final_name'])[0] + '.vtt')
self.resampler_cache = {}
self.audio_segments = []
self._build()
except Exception as e:
error = f'__init__() error: {e}'
print(error)
return None
def _build(self):
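        """Return the engine from the shared loaded_tts cache, loading it on first use."""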
try:
tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
if not tts:
if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
if self.session['custom_model'] is not None:
msg = f"{self.session['tts_engine']} custom model not implemented yet!"
print(msg)
return False
else:
model_path = models[self.session['tts_engine']][self.session['fine_tuned']]['repo']
tts = self._load_api(self.tts_key, model_path, self.session['device'])
return (loaded_tts.get(self.tts_key) or {}).get('engine', False)
except Exception as e:
            error = f'_build() error: {e}'
print(error)
return False
def _load_api(self, key, model_path, device):
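        """Instantiate the engine from model_path and register it in loaded_tts under key."""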
global lock
try:
            if key in loaded_tts:
return loaded_tts[key]['engine']
unload_tts(device, [self.tts_key, self.tts_vc_key])
with lock:
                tts = NEW_TTS(model_path)  # placeholder: replace with the real engine's constructor
                if tts:
                    if device == 'cuda':
                        NEW_TTS.WITH_CUDA  # placeholder: move the engine to CUDA here
                    else:
                        NEW_TTS.WITHOUT_CUDA  # placeholder: keep the engine on CPU here
                    loaded_tts[key] = {"engine": tts, "config": None}
                    msg = f'{model_path} Loaded!'
                    print(msg)
                    return tts
                else:
                    error = 'TTS engine could not be created!'
                    print(error)
except Exception as e:
error = f'_load_api() error: {e}'
print(error)
return False
def _load_checkpoint(self, **kwargs):
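        """Load the engine from a local checkpoint directory and register it in loaded_tts."""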
global lock
try:
key = kwargs.get('key')
            if key in loaded_tts:
return loaded_tts[key]['engine']
tts_engine = kwargs.get('tts_engine')
device = kwargs.get('device')
unload_tts(device, [self.tts_key])
with lock:
                checkpoint_dir = kwargs.get('checkpoint_dir')
                config = kwargs.get('config')  # assumed to be passed by the caller alongside checkpoint_dir
                tts = NEW_TTS.LOAD_CHECKPOINT(  # placeholder loader; its return value must be captured
                    config,
                    checkpoint_dir=checkpoint_dir,
                    eval=True
                )
                if tts:
                    if device == 'cuda':
                        NEW_TTS.WITH_CUDA  # placeholder: move the engine to CUDA here
                    else:
                        NEW_TTS.WITHOUT_CUDA  # placeholder: keep the engine on CPU here
                    loaded_tts[key] = {"engine": tts, "config": config}
                    msg = f'{tts_engine} Loaded!'
                    print(msg)
                    return tts
                else:
                    error = 'TTS engine could not be created!'
                    print(error)
except Exception as e:
            error = f'_load_checkpoint() error: {e}'
            print(error)
        return False
def _tensor_type(self, audio_data):
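        """Coerce raw engine output (torch.Tensor, np.ndarray or list) into a float32 torch.Tensor."""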
if isinstance(audio_data, torch.Tensor):
return audio_data
elif isinstance(audio_data, np.ndarray):
return torch.from_numpy(audio_data).float()
elif isinstance(audio_data, list):
return torch.tensor(audio_data, dtype=torch.float32)
else:
raise TypeError(f"Unsupported type for audio_data: {type(audio_data)}")
def _get_resampler(self, orig_sr, target_sr):
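        """Return a cached torchaudio Resample transform for the (orig_sr, target_sr) pair."""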
key = (orig_sr, target_sr)
if key not in self.resampler_cache:
self.resampler_cache[key] = torchaudio.transforms.Resample(
orig_freq=orig_sr, new_freq=target_sr
)
return self.resampler_cache[key]
def _resample_wav(self, wav_path, expected_sr):
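        """Downmix to mono and resample to expected_sr, writing a temporary 16-bit PCM wav when conversion is needed."""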
waveform, orig_sr = torchaudio.load(wav_path)
if orig_sr == expected_sr and waveform.size(0) == 1:
return wav_path
if waveform.size(0) > 1:
waveform = waveform.mean(dim=0, keepdim=True)
if orig_sr != expected_sr:
resampler = self._get_resampler(orig_sr, expected_sr)
waveform = resampler(waveform)
wav_tensor = waveform.squeeze(0)
wav_numpy = wav_tensor.cpu().numpy()
tmp_fh = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
tmp_path = tmp_fh.name
tmp_fh.close()
sf.write(tmp_path, wav_numpy, expected_sr, subtype="PCM_16")
return tmp_path
def convert(self, sentence_number, sentence):
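        """Synthesize one sentence to <sentence_number>.<default_audio_proc_format> and append its cue to the VTT file.

        TTS_SML 'break' and 'pause' tokens produce randomized silence instead of speech.
        """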
try:
speaker = None
audio_data = False
trim_audio_buffer = 0.004
settings = self.params[self.session['tts_engine']]
final_sentence_file = os.path.join(self.session['chapters_dir_sentences'], f'{sentence_number}.{default_audio_proc_format}')
sentence = sentence.strip()
settings['voice_path'] = (
self.session['voice'] if self.session['voice'] is not None
else os.path.join(self.session['custom_model_dir'], self.session['tts_engine'], self.session['custom_model'], 'ref.wav') if self.session['custom_model'] is not None
else models[self.session['tts_engine']][self.session['fine_tuned']]['voice']
)
if settings['voice_path'] is not None:
speaker = re.sub(r'\.wav$', '', os.path.basename(settings['voice_path']))
tts = (loaded_tts.get(self.tts_key) or {}).get('engine', False)
if tts:
                if sentence and sentence[-1].isalnum():
sentence = f'{sentence} —'
if sentence == TTS_SML['break']:
                    break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100))) # 0.3 to 0.6 seconds
self.audio_segments.append(break_tensor.clone())
return True
elif sentence == TTS_SML['pause']:
pause_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(1.0, 1.8) * 100) / 100))) # 1.0 to 1.8 seconds
self.audio_segments.append(pause_tensor.clone())
return True
else:
if self.session['tts_engine'] == TTS_ENGINES['NEW_TTS']:
audio_sentence = NEW_TTS.CONVERT() # audio_sentence must be torch.Tensor or (list, tuple) or np.ndarray
if is_audio_data_valid(audio_sentence):
sourceTensor = self._tensor_type(audio_sentence)
audio_tensor = sourceTensor.clone().detach().unsqueeze(0).cpu()
if sentence[-1].isalnum() or sentence[-1] == '—':
audio_tensor = trim_audio(audio_tensor.squeeze(), settings['samplerate'], 0.003, trim_audio_buffer).unsqueeze(0)
self.audio_segments.append(audio_tensor)
if not re.search(r'\w$', sentence, flags=re.UNICODE):
break_tensor = torch.zeros(1, int(settings['samplerate'] * (int(np.random.uniform(0.3, 0.6) * 100) / 100)))
self.audio_segments.append(break_tensor.clone())
if self.audio_segments:
audio_tensor = torch.cat(self.audio_segments, dim=-1)
start_time = self.sentences_total_time
duration = audio_tensor.shape[-1] / settings['samplerate']
end_time = start_time + duration
self.sentences_total_time = end_time
sentence_obj = {
"start": start_time,
"end": end_time,
"text": sentence,
"resume_check": self.sentence_idx
}
self.sentence_idx = append_sentence2vtt(sentence_obj, self.vtt_path)
if self.sentence_idx:
torchaudio.save(final_sentence_file, audio_tensor, settings['samplerate'], format=default_audio_proc_format)
del audio_tensor
self.audio_segments = []
if os.path.exists(final_sentence_file):
return True
else:
error = f"Cannot create {final_sentence_file}"
print(error)
else:
error = f"convert() error: {self.session['tts_engine']} is None"
print(error)
except Exception as e:
            error = f'Coqui.convert() error: {e}'
            print(error)
            raise ValueError(e)
return False
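
# Minimal usage sketch (assumes a `session` dict populated by the host application
# with keys such as 'tts_engine', 'fine_tuned', 'device', 'voice' and 'process_dir'):
#
#     coqui = Coqui(session)
#     coqui.convert(1, 'Hello world.')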