# Speech-to-Speech-FM / TTS_models.py
# Last commit by HanaeRateau: e9c2890 ("first commit"); file size 4.71 kB.
from abc import ABC, abstractmethod
import io
import numpy as np
import torch
from transformers import pipeline
from datasets import load_dataset
class TTSModel(ABC):
    """Abstract base class for the text-to-speech back-ends in this module.

    Subclasses must override :meth:`synthesize`. Because the class derives
    from :class:`abc.ABC`, neither ``TTSModel`` itself nor a subclass that
    forgets to override ``synthesize`` can be instantiated (``TypeError``).

    Attributes:
        hf_name: Model identifier handed to the constructor by the subclass.
        device: ``"cuda:0"`` when CUDA is available, otherwise ``"cpu"``.
    """

    def __init__(self, model_name):
        """Store the model identifier and pick the compute device.

        Args:
            model_name: Identifier of the model the subclass will load.
        """
        self.hf_name = model_name
        # Prefer the first GPU when torch can see one; fall back to CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

    @abstractmethod
    def synthesize(self, text):
        """Convert *text* to audio; the returned format is backend-specific."""
#####
####################################################
class SpeechT5(TTSModel):
    """Microsoft SpeechT5 text-to-speech driven through a ``transformers`` pipeline.

    The voice is conditioned on a speaker x-vector taken from the
    ``Matthijs/cmu-arctic-xvectors`` validation split.
    """

    def __init__(self, name="microsoft/speecht5_tts", speaker_index=7306):
        """Load the TTS pipeline and the speaker embedding.

        Args:
            name: Model id passed to the ``"text-to-speech"`` pipeline.
            speaker_index: Row of the x-vector dataset used as the voice.
                Defaults to 7306, the row previously hard-coded here.
        """
        super().__init__(name)
        self.synthesiser = pipeline("text-to-speech", model=self.hf_name, device=self.device)
        self.embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
        # unsqueeze(0) adds a leading size-1 dimension (presumably the batch axis
        # the model expects) to the single x-vector.
        self.speaker_embedding = torch.tensor(
            self.embeddings_dataset[speaker_index]["xvector"]
        ).unsqueeze(0)

    def synthesize(self, text):
        """Synthesize *text* and return it as 16-bit PCM samples.

        Args:
            text: The text to speak.

        Returns:
            A numpy array of dtype ``int16`` suitable for playback.
        """
        speech = self.synthesiser(text, forward_params={"speaker_embeddings": self.speaker_embedding})
        print("[SpeechT5 - synthesize]", speech)
        # Clip to [-1, 1] before scaling: out-of-range float samples would
        # otherwise overflow the int16 cast (platform-dependent garbage/clicks).
        audio = np.clip(np.asarray(speech["audio"]), -1.0, 1.0)
        return (audio * 32767).astype(np.int16)
####################################################
# PENDING: NOT WORKING FROM HF
# from MeloTTS.melo.api import TTS as meloTTS
# import nltk
# class MeloTTS(TTSModel):
# def __init__(self, name="myshell-ai/MeloTTS-English"):
# super(MeloTTS, self).__init__(name)
# nltk.download('averaged_perceptron_tagger_eng')
# self.synthesiser = meloTTS(language='EN', device=self.device)
# self.speaker_ids = self.synthesiser.hps.data.spk2id
# def synthesize(self, text):
# speech = self.synthesiser.tts_to_file(text, self.speaker_ids['EN-Default'])
# print("[MeloTTS - synthesize]", speech)
# return speech
####################################################
class Bark(TTSModel):
    """Suno Bark text-to-speech wrapped in a ``transformers`` pipeline.

    ``synthesize`` hands back whatever the pipeline returns; no conversion
    to integer PCM is applied here.
    """

    def __init__(self, name="suno/bark"):
        """Build the ``"text-to-speech"`` pipeline for model *name*."""
        super().__init__(name)
        self.synthesiser = pipeline(
            "text-to-speech",
            model=self.hf_name,
            device=self.device,
        )

    def synthesize(self, text):
        """Run the Bark pipeline on *text* and return its raw output."""
        result = self.synthesiser(text)
        print("[Bark - synthesize]", result)
        return result
####################################################
# pip install git+https://github.com/huggingface/parler-tts.git
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
class ParlerTTS(TTSModel):
    """Parler-TTS: speech conditioned on a natural-language voice description.

    Attributes:
        description: Text describing the desired voice; tokenized and passed
            to ``generate`` as ``input_ids`` on every call.
        model: The loaded ``ParlerTTSForConditionalGeneration``.
        tokenizer: Tokenizer used for both the description and the prompt.
    """

    # Voice description used when the caller does not supply one.
    DEFAULT_DESCRIPTION = (
        "A female speaker delivers a slightly expressive and animated speech "
        "with a moderate speed and pitch. The recording is of very high quality, "
        "with the speaker's voice sounding clear and very close up."
    )

    def __init__(self, name="parler-tts/parler-tts-large-v1", description=None):
        """Load the Parler-TTS model and tokenizer.

        Args:
            name: Model id to load from the Hugging Face Hub.
            description: Optional voice description; ``None`` selects
                :attr:`DEFAULT_DESCRIPTION`.
        """
        super().__init__(name)
        self.description = self.DEFAULT_DESCRIPTION if description is None else description
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(self.hf_name).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.hf_name)

    def synthesize(self, text):
        """Generate speech for *text* in the voice given by ``self.description``.

        Returns:
            The generated waveform as a squeezed numpy array (presumably
            float samples at the model's sampling rate; verify before playback).
        """
        # The description is re-tokenized per call so later edits to
        # self.description take effect.
        input_ids = self.tokenizer(self.description, return_tensors="pt").input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
        generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
        speech = generation.cpu().numpy().squeeze()
        print("[ParlerTTS - synthesize]", speech)
        return speech
####################################################
# PENDING: NOT WORKING FROM HF
# pip install coqui-tts
# https://github.com/idiap/coqui-ai-TTS
from TTS.api import TTS
class XTTS(TTSModel):
    """Wrapper around the Coqui ``TTS.api.TTS`` engine.

    Despite the class name, the default model is Coqui's LJSpeech Glow-TTS
    (``tts_models/en/ljspeech/glow-tts``). ``name`` is a Coqui model name
    rather than a Hugging Face Hub id.
    """

    def __init__(self, name="tts_models/en/ljspeech/glow-tts"):
        """Instantiate the Coqui engine for *name* and move it to ``self.device``."""
        super().__init__(name)
        engine = TTS(model_name=name, progress_bar=False)
        self.synthesiser = engine.to(self.device)

    def synthesize(self, text):
        """Synthesize *text* and return the waveform produced by ``TTS.tts``."""
        waveform = self.synthesiser.tts(text=text)
        print("[XTTS - synthesize]", len(waveform), text)
        return waveform