File size: 5,411 Bytes
216406c
 
91180fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216406c
91180fb
 
216406c
91180fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216406c
91180fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os

# Redirect the Hugging Face, torch and XDG cache directories to /tmp.
# These are set before the torch/transformers imports that follow —
# presumably so those libraries resolve their cache paths from these values,
# and presumably because /tmp is writable on the deployment host (confirm).
os.environ.update(
    dict.fromkeys(
        ("HF_HOME", "TRANSFORMERS_CACHE", "TORCH_HOME", "XDG_CACHE_HOME"),
        "/tmp",
    )
)

import io
import math
import re

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response, StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, VitsModel

app = FastAPI()

# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the VITS weights and the tokenizer once, at import time, so every
# request reuses the same instances. `.to()` and `.eval()` both return the
# module itself, so the chain leaves `model` on `device` in evaluation mode.
# NOTE(review): weights and tokenizer come from two different Hub
# repositories — confirm the tokenizer vocabulary matches the checkpoint.
model = VitsModel.from_pretrained("Somali-tts/somali_tts_model").to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")

# Somali words for the integers that are not composed from smaller parts:
# 0-19, the tens 20-90, and the words for 100 and 1000.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
    14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
    17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton",
    60: "lixdan", 70: "todobaatan", 80: "sideetan", 90: "sagaashan",
    100: "boqol", 1000: "kun"
}


def _spell_scaled(number: int, base: int, word: str) -> str:
    """Spell ``number`` (>= ``base``) as "<count> <word> iyo <remainder>".

    A count of exactly one is written as the bare scale word ("boqol",
    "kun", "milyan") rather than "koow <word>"; the " iyo <remainder>" part
    is omitted when ``number`` is an exact multiple of ``base``.
    """
    count, remainder = divmod(number, base)
    head = word if count == 1 else f"{number_to_words(count)} {word}"
    if remainder:
        return f"{head} iyo {number_to_words(remainder)}"
    return head


def number_to_words(number: int) -> str:
    """Spell a non-negative integer out in Somali words.

    Values below one billion are built from ``number_words`` plus the scale
    words "boqol" (100), "kun" (1,000) and "milyan" (1,000,000); every
    remainder is attached with "iyo" at every scale (the millions scale
    previously dropped it, yielding e.g. "milyan shan" for 1,000,005).
    Values of one billion or more are returned as their decimal string.

    Raises:
        KeyError: for negative input, which has no entry in ``number_words``.
    """
    if number < 20:
        return number_words[number]
    if number < 100:
        tens, unit = divmod(number, 10)
        head = number_words[tens * 10]
        return f"{head} iyo {number_words[unit]}" if unit else head
    if number < 1_000:
        return _spell_scaled(number, 100, "boqol")
    if number < 1_000_000:
        return _spell_scaled(number, 1_000, "kun")
    if number < 1_000_000_000:
        return _spell_scaled(number, 1_000_000, "milyan")
    # No scale word is defined beyond millions; leave the digits as they are.
    return str(number)

def normalize_text(text: str) -> str:
    """Rewrite ``text`` into the form fed to the Somali TTS tokenizer.

    Every run of digits is spelled out with ``number_to_words``; then the
    case-sensitive respellings "KH" -> "qa", "Z" -> "S", "SH" -> "SHa'a"
    and "DH" -> "Dha'a" are applied, in that order.
    """
    # Substitute each digit run at its own position. A findall + str.replace
    # loop would also rewrite a short number inside a longer one, turning
    # "1 12" into "koow koow2".
    text = re.sub(r"\d+", lambda m: number_to_words(int(m.group())), text)
    text = text.replace("KH", "qa").replace("Z", "S")
    text = text.replace("SH", "SHa'a").replace("DH", "Dha'a")
    # "ZamZam" -> "SamSam" needs no rule of its own: the "Z" -> "S"
    # replacement above already produces it.
    return text

def waveform_to_wav_bytes(waveform: torch.Tensor, sample_rate: int = 22050) -> bytes:
    """Encode a float waveform tensor as a mono 16-bit PCM WAV file.

    Args:
        waveform: audio samples, nominally in [-1.0, 1.0]. A 3-D tensor is
            reduced to its first entry along dim 0; a 2-D tensor is then
            averaged over dim 0 to give a single channel.
        sample_rate: sampling rate written into the WAV header, in Hz.

    Returns:
        The complete WAV file as bytes.
    """
    # detach() allows tensors that still track gradients; .numpy() raises
    # on those otherwise.
    samples = waveform.detach().cpu().numpy()
    if samples.ndim == 3:
        samples = samples[0]
    if samples.ndim == 2:
        samples = samples.mean(axis=0)
    # Clip before scaling so out-of-range samples saturate instead of
    # wrapping around when cast to int16.
    samples = np.clip(samples, -1.0, 1.0).astype(np.float32)
    pcm = (samples * 32767).astype(np.int16)
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm)
    return buf.getvalue()

class TextIn(BaseModel):
    """JSON request body for the POST /synthesize endpoint."""

    inputs: str  # text to synthesize; passed through normalize_text before tokenizing

def _waveform_from_output(output):
    """Pull the waveform tensor out of a model output, or return None.

    Accepts an object exposing ``.waveform``, a dict with a ``"waveform"``
    key, or a tuple/list whose first element is taken as the waveform.
    """
    if hasattr(output, "waveform"):
        return output.waveform
    if isinstance(output, dict) and "waveform" in output:
        return output["waveform"]
    if isinstance(output, (tuple, list)):
        return output[0]
    return None


def _synthesize_response(text: str) -> Response:
    """Normalize, tokenize and synthesize ``text``; return an HTTP response.

    This runs model inference synchronously, so the async endpoints call it
    through ``run_in_threadpool`` instead of blocking the event loop.

    Returns a 400 JSON error for blank input, a 500 JSON error when no
    waveform can be found in the model output, and otherwise the audio as
    an ``audio/wav`` response.
    """
    normalized = normalize_text(text)
    if not normalized.strip():
        return JSONResponse(status_code=400, content={"error": "Input text is empty"})
    inputs = tokenizer(normalized, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model(**inputs)
    waveform = _waveform_from_output(output)
    if waveform is None:
        # Same {"error": ...} body as before, but with a failure status
        # instead of 200 so clients can detect it.
        return JSONResponse(
            status_code=500,
            content={"error": "Waveform not found in model output"},
        )
    sample_rate = getattr(model.config, "sampling_rate", 22050)
    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
    # The whole file is already in memory, so a plain Response (which also
    # sets Content-Length) fits better than a streaming one.
    return Response(content=wav_bytes, media_type="audio/wav")


def _test_tone_wav_bytes(duration_s: float = 2.0, sample_rate: int = 22050, freq: float = 440) -> bytes:
    """Return a WAV file holding a half-amplitude sine tone (440 Hz, 2 s by default)."""
    t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
    tone = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
    pcm = (tone * 32767).astype(np.int16)
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm)
    return buf.getvalue()


@app.post("/synthesize")
async def synthesize_post(data: TextIn):
    """Synthesize speech for the JSON body ``{"inputs": "<text>"}``."""
    return await run_in_threadpool(_synthesize_response, data.inputs)


@app.get("/synthesize")
async def synthesize_get(text: str = Query(..., description="Text to synthesize"), test: bool = Query(False)):
    """Synthesize speech for the ``text`` query parameter.

    With ``test=true`` the model is not used: a fixed sine tone is returned
    instead.
    """
    if test:
        return Response(content=_test_tone_wav_bytes(), media_type="audio/wav")
    return await run_in_threadpool(_synthesize_response, text)