# Environment settings: route all cache/download directories to /tmp
import os
os.environ["HF_HOME"] = "/tmp"
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
os.environ["TORCH_HOME"] = "/tmp"
os.environ["XDG_CACHE_HOME"] = "/tmp"

import io
import re
import math

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI, Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import VitsModel, AutoTokenizer

app = FastAPI()

# Load the VITS model and its tokenizer, then move the model to GPU if available.
model = VitsModel.from_pretrained("Somali-tts/somali_tts_model")
tokenizer = AutoTokenizer.from_pretrained("saleolow/somali-mms-tts")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Somali number words used to spell out digits in the input text.
number_words = {
    0: "eber", 1: "koow", 2: "labo", 3: "seddex", 4: "afar", 5: "shan",
    6: "lix", 7: "todobo", 8: "sideed", 9: "sagaal", 10: "toban",
    11: "toban iyo koow", 12: "toban iyo labo", 13: "toban iyo seddex",
    14: "toban iyo afar", 15: "toban iyo shan", 16: "toban iyo lix",
    17: "toban iyo todobo", 18: "toban iyo sideed", 19: "toban iyo sagaal",
    20: "labaatan", 30: "sodon", 40: "afartan", 50: "konton", 60: "lixdan",
    70: "todobaatan", 80: "sideetan", 90: "sagaashan", 100: "boqol", 1000: "kun",
}

# Common chat shortcuts expanded to their full Somali phrases.
shortcut_map = {
    "asc": "asalaamu caleykum",
    "wcs": "wacaleykum salaam",
    "fcn": "fiican",
    "xld": "xaaladda ka waran",
    "kwrn": "kawaran",
    "scw": "salalaahu caleyhi wa salam",
    "alx": "alxamdu lilaahi",
    "m.a": "maasha allah",
    "sthy": "side tahey",
    "sxp": "saaxiib",
}

# English country names mapped to their Somali spelling.
country_map = {
    "somalia": "Soomaaliya",
    "ethiopia": "Itoobiya",
    "kenya": "Kenya",
    "djibouti": "Jabuuti",
    "sudan": "Suudaan",
    "yemen": "yemaan",
    "uganda": "Ugaandha",
    "tanzania": "Tansaaniya",
    "egypt": "Masar",
    "libya": "Liibiya",
    "algeria": "Aljeeriya",
    "morocco": "Morooko",
    "tunisia": "Tuniisiya",
    "eritrea": "Eriteriya",
    "malawi": "Malaawi",
    "english": "ingiriis",
    "spain": "isbeen",
    "brazil": "baraasiil",
    "niger": "Niyjer",
    "italy": "itaaliya",
    "united states": "Maraykanka",
    "china": "Shiinaha",
    "india": "Hindiya",
    "russia": "Ruushka",
    "saudi arabia": "Sucuudi Carabiya",
    "germany": "Jarmalka",
    "france": "Faransiiska",
    "japan": "Jabaan",
    "canada": "Kanada",
    "australia": "Australia",
}


def number_to_words(number):
    """Convert an integer (or numeric string) into Somali number words."""
    number = int(number)
    if number < 20:
        return number_words[number]
    elif number < 100:
        tens, unit = divmod(number, 10)
        return number_words[tens * 10] + (" iyo " + number_words[unit] if unit else "")
    elif number < 1000:
        hundreds, remainder = divmod(number, 100)
        part = (number_words[hundreds] + " boqol") if hundreds > 1 else "boqol"
        if remainder:
            part += " iyo " + number_to_words(remainder)
        return part
    elif number < 1000000:
        thousands, remainder = divmod(number, 1000)
        words = [number_to_words(thousands) + " kun" if thousands > 1 else "kun"]
        if remainder:
            words.append("iyo " + number_to_words(remainder))
        return " ".join(words)
    elif number < 1000000000:
        millions, remainder = divmod(number, 1000000)
        words = [number_to_words(millions) + " milyan" if millions > 1 else "milyan"]
        if remainder:
            words.append(number_to_words(remainder))
        return " ".join(words)
    else:
        return str(number)


def normalize_text(text):
    # The body of this function was truncated in the source. The steps below are
    # an assumed reconstruction based on shortcut_map, country_map, and
    # number_to_words defined above; the exact original regexes are unknown.
    for short, full in shortcut_map.items():
        text = re.sub(r"\b" + re.escape(short) + r"\b", full, text, flags=re.IGNORECASE)
    for eng, som in country_map.items():
        text = re.sub(r"\b" + re.escape(eng) + r"\b", som, text, flags=re.IGNORECASE)
    # Spell out digit sequences as Somali number words.
    text = re.sub(r"\d+", lambda m: number_to_words(m.group()), text)
    return text
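# Worked examples of the numeral expansion, traced directly from the
# number_to_words logic above (shown here only as illustrative comments):
#   number_to_words(25)   -> "labaatan iyo shan"
#   number_to_words(105)  -> "boqol iyo shan"
#   number_to_words(2024) -> "labo kun iyo labaatan iyo afar"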
def waveform_to_wav_bytes(waveform, sample_rate) -> bytes:
    """Convert a model waveform tensor into 16-bit PCM WAV bytes."""
    np_waveform = waveform.cpu().numpy()
    # Collapse batch/channel dimensions down to a single mono track.
    if np_waveform.ndim == 3:
        np_waveform = np_waveform[0]
    if np_waveform.ndim == 2:
        np_waveform = np_waveform.mean(axis=0)
    np_waveform = np.clip(np_waveform, -1.0, 1.0).astype(np.float32)
    pcm_waveform = (np_waveform * 32767).astype(np.int16)
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
    buf.seek(0)
    return buf.read()


class TextIn(BaseModel):
    inputs: str


@app.post("/synthesize")
async def synthesize_post(data: TextIn):
    # Synthesize each non-empty paragraph separately, then join them with silence.
    paragraphs = [p.strip() for p in data.inputs.split("\n") if p.strip()]
    sample_rate = getattr(model.config, "sampling_rate", 22050)
    all_waveforms = []

    for paragraph in paragraphs:
        normalized = normalize_text(paragraph)
        inputs = tokenizer(normalized, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model(**inputs)
        waveform = (
            output.waveform if hasattr(output, "waveform")
            else output["waveform"] if isinstance(output, dict) and "waveform" in output
            else output[0] if isinstance(output, (tuple, list))
            else None
        )
        if waveform is None:
            continue
        all_waveforms.append(waveform)
        # One second of silence between paragraphs.
        silence = torch.zeros(1, sample_rate).to(waveform.device)
        all_waveforms.append(silence)

    if not all_waveforms:
        return {"error": "No audio generated."}

    final_waveform = torch.cat(all_waveforms, dim=-1)
    wav_bytes = waveform_to_wav_bytes(final_waveform, sample_rate=sample_rate)
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")


@app.get("/synthesize")
async def synthesize_get(
    text: str = Query(..., description="Text to synthesize"),
    test: bool = Query(False),
):
    if test:
        # Test mode: return a 440 Hz sine tone, six seconds per paragraph,
        # without touching the model at all.
        paragraphs = text.count("\n") + 1
        duration_s = paragraphs * 6
        sample_rate = 22050
        t = np.linspace(0, duration_s, int(sample_rate * duration_s), endpoint=False)
        freq = 440
        waveform = 0.5 * np.sin(2 * math.pi * freq * t).astype(np.float32)
        pcm_waveform = (waveform * 32767).astype(np.int16)
        buf = io.BytesIO()
        scipy.io.wavfile.write(buf, rate=sample_rate, data=pcm_waveform)
        buf.seek(0)
        return StreamingResponse(buf, media_type="audio/wav")

    normalized = normalize_text(text)
    inputs = tokenizer(normalized, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model(**inputs)
    waveform = (
        output.waveform if hasattr(output, "waveform")
        else output["waveform"] if isinstance(output, dict) and "waveform" in output
        else output[0] if isinstance(output, (tuple, list))
        else None
    )
    if waveform is None:
        return {"error": "Waveform not found in model output"}
    sample_rate = getattr(model.config, "sampling_rate", 22050)
    wav_bytes = waveform_to_wav_bytes(waveform, sample_rate=sample_rate)
    return StreamingResponse(io.BytesIO(wav_bytes), media_type="audio/wav")
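# Usage sketch (illustrative; the module name "app" and port 8000 are
# assumptions, not part of the original service definition):
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   # POST with a JSON body matching TextIn:
#   curl -X POST http://localhost:8000/synthesize \
#        -H "Content-Type: application/json" \
#        -d '{"inputs": "Soomaaliya waa dal qurux badan"}' \
#        --output speech.wav
#
#   # GET with query parameters (test=true returns the sine-tone placeholder):
#   curl "http://localhost:8000/synthesize?text=salaan&test=true" --output tone.wav

# Local run convenience: a sketch assuming uvicorn is installed; host and port
# are illustrative defaults.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)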