# Malay-VITS / app.py — Hugging Face Space by huseinzol05
# Commit eb83dc0 ("initial"), 2.94 kB
import os
# Hide every CUDA device so the model below runs on CPU. This must be set before
# torch is imported further down for it to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
# TensorFlow-specific setting; presumably for any TF code pulled in by
# malaya_speech — TODO confirm it is still needed, torch ignores it.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
from huggingface_hub import snapshot_download
from malaya_speech.torch_model.vits.model_infer import SynthesizerTrn
from malaya_speech.torch_model.vits.commons import intersperse
from malaya_speech.utils.text import TTS_SYMBOLS
from malaya_speech.tts import load_text_ids
from malaya_speech.utils.astype import float_to_int
import gradio as gr
import torch
import os
import json
# Prefer the current HParams location and fall back to the older module path.
# Only a failed import triggers the fallback; catching BaseException here would
# also swallow KeyboardInterrupt/SystemExit and mask unrelated errors.
try:
    from malaya_boilerplate.hparams import HParams
except ImportError:
    from malaya_boilerplate.train.config import HParams
# Display name -> integer speaker index passed to the model as `sid`.
# Dict order follows the tuple, and the dropdown lists speakers in that order.
speaker_id = {
    name: index
    for index, name in enumerate(('Husein', 'Shafiqah Idayu'))
}
# Text -> symbol-id converter used by tts(); configured with
# understand_punct=True and is_lower=False (presumably: keep punctuation and
# do not lowercase — confirm against malaya_speech docs).
normalizer = load_text_ids(pad_to=None, understand_punct=True, is_lower=False)

# Download the multi-speaker VITS checkpoint and its hyper-parameter config.
folder = snapshot_download(repo_id="malaysia-ai/malay-VITS-multispeaker")
# JSON is UTF-8 by spec; don't depend on the platform's default text encoding.
with open(os.path.join(folder, 'config.json'), encoding='utf-8') as fopen:
    hps = HParams(**json.load(fopen))

model = SynthesizerTrn(
    len(TTS_SYMBOLS),
    # presumably the one-sided STFT bin count for FFT size filter_length
    hps.data.filter_length // 2 + 1,
    # training segment length converted from samples to frames
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model,
).eval()  # inference mode for the whole app lifetime
# NOTE(review): torch.load unpickles the downloaded file, so only load from a
# trusted repo; consider weights_only=True on torch versions that support it.
model.load_state_dict(torch.load(os.path.join(folder, 'model.pth'), map_location='cpu'))
def tts(text, speaker, temperature, length_ratio):
    """Synthesize Malay speech for ``text`` in the voice of ``speaker``.

    Args:
        text: Input text to synthesize; must contain non-whitespace characters.
        speaker: Display name of the speaker; must be a key of ``speaker_id``.
        temperature: Sampling noise scale, forwarded to ``model.infer`` as both
            ``noise_scale`` and ``noise_scale_w`` (the UI describes it as
            manipulating pitch). 0.0 disables sampling noise.
        length_ratio: Forwarded as ``length_scale``; per the UI it manipulates
            the duration of the output. Must be greater than 0.

    Returns:
        ``(sample_rate, samples)`` as consumed by a Gradio ``audio`` output,
        where ``samples`` comes from ``float_to_int``.

    Raises:
        gr.Error: on empty/blank text, an unknown speaker, or a non-positive
            length ratio.
    """
    if not text or not text.strip():
        raise gr.Error('input text must longer than 0')
    if speaker not in speaker_id:
        raise gr.Error('speaker is not available')
    if length_ratio is None or length_ratio <= 0:
        raise gr.Error('length ratio must be greater than 0')

    _, ids = normalizer.normalize(text, add_fullstop=True)
    if hps.data.add_blank:
        # The config asks for a blank token (id 0) between consecutive symbol ids.
        ids = intersperse(ids, 0)
    ids = torch.LongTensor(ids).unsqueeze(0)  # batch of one sequence
    ids_lengths = torch.LongTensor([ids.size(1)])
    sid = torch.tensor([speaker_id[speaker]])

    with torch.no_grad():
        # Previously these were hard-coded (0.0, 0.0, 1.0), which made the
        # temperature and length-ratio sliders no-ops.
        audio = model.infer(
            ids,
            ids_lengths,
            noise_scale=temperature,
            noise_scale_w=temperature,
            length_scale=length_ratio,
            sid=sid,
        )
    # audio[0] is indexed [0, 0], i.e. assumed (batch=1, channel=1, samples) — TODO confirm
    samples = float_to_int(audio[0].numpy()[0, 0])
    # Use the configured rate when present; 22050 was the previous hard-coded value.
    sample_rate = getattr(hps.data, 'sampling_rate', 22050)
    return (sample_rate, samples)
# Gradio UI. The order of `inputs` must match
# tts(text, speaker, temperature, length_ratio), and each example row supplies
# one value per input in that same order.
demo = gr.Interface(
    fn=tts,
    inputs=[
        gr.components.Textbox(label='Text'),
        gr.components.Dropdown(label='Available speakers', choices=list(speaker_id.keys()), value='Husein'),
        gr.Slider(0.0, 1.0, value=0.6666, label='temperature, changing this will manipulate pitch'),
        # Minimum kept above 0: a zero length ratio presumably collapses the
        # synthesized duration to (almost) nothing.
        gr.Slider(0.1, 3.0, value=1.0, label='length ratio, changing this will manipulate duration output'),
    ],
    outputs=['audio'],
    examples=[
        ['Syed Saddiq berkata, mereka seharusnya mengingati bahawa semasa menjadi Perdana Menteri Pakatan Harapan', 'Husein', 0.6666, 1.0],
        ['Shah Alam - Pertubuhan Kebajikan Anak Bersatu Selangor bersetuju pihak kerajaan mewujudkan Suruhanjaya Siasatan Diraja untuk menyiasat isu kartel daging.', 'Shafiqah Idayu', 0.6666, 1.0],
    ],
    cache_examples=False,
    title='End-to-End Malay TTS using VITS',
)
# Enable request queuing, then serve on all network interfaces (0.0.0.0).
demo.queue().launch(server_name='0.0.0.0')