# Spaces: Runtime error
# (Non-code header text captured together with this file; kept as a comment.)
import os
import uuid
from io import BytesIO
from typing import Optional

import scipy.io.wavfile as wavfile
import soundfile as sf
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForSequenceClassification,
    AutoProcessor,
    AutoTokenizer,
    BarkModel,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    pipeline,
)
# FastAPI instance
app = FastAPI(title="Kinyarwanda Engine", version="1.0")

# Config
# Location the Bark text-to-speech tokenizer, processor and model are loaded from.
MODEL_PATH = "/app/models/suno-bark"
# Location the sentiment tokenizer and classifier are loaded from.
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
# Sample rate (Hz) written into the WAV header of synthesized speech.
SAMPLE_RATE = 24000
# Identifier the wav2vec2 speech-recognition processor and model are loaded from.
# NOTE(review): looks like a model-hub id rather than a local path — confirm
# the deployment can resolve it.
ASR_MODEL_PATH = "lucio/wav2vec2-large-xlsr-kinyarwanda"

# Ensure working directory for audio
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
# Load models
# Everything is loaded once at import time so the request handlers can reuse
# the same objects. Any failure is re-raised as RuntimeError, which aborts the
# import of this module.
try:
    # Device config: chosen once so every model and its inputs agree on placement.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TTS
    # NOTE(review): `tokenizer` is not used by any handler visible in this
    # file; it is kept because other code may import it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    model.to(device)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # STT
    asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
    asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
    asr_model.to(device)
except Exception as e:
    # Chain the original exception so the real cause stays in the traceback.
    raise RuntimeError(f"Model initialization failed: {e}") from e
# Request schemas
class TTSRequest(BaseModel):
    """Request body for text-to-speech synthesis."""

    # Text to be converted to speech.
    text: str
class SentimentRequest(BaseModel):
    """Request body for sentiment analysis."""

    # Text to classify.
    text: str
class LegalDocRequest(BaseModel):
    """Request body for the keyword-based legal document parser."""

    # Document text to scan for keywords.
    text: str
    # Caller-supplied domain label; echoed back unchanged in the response.
    domain: Optional[str] = "general"
# Root route
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
def root():
    """Return the service's static welcome message."""
    return {"message": "Welcome to Kinyarwanda Engine"}
# Text-to-Speech endpoint
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as a WAV download.

    The text is run through the module-level ``processor`` and ``model``; the
    generated waveform is written to an in-memory WAV file at ``SAMPLE_RATE``
    and streamed back with a randomly named ``tts_<hex>.wav`` attachment.

    Args:
        request: Body carrying the text to synthesize.

    Returns:
        StreamingResponse with media type ``audio/wav``.

    Raises:
        HTTPException: Status 500 with detail ``"TTS generation failed: ..."``
            if any step fails.
    """
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        audio_data = audio_array.cpu().numpy().squeeze()

        # NOTE(review): assumes the model emits audio at SAMPLE_RATE — confirm
        # against the model's generation config.
        buffer = BytesIO()
        wavfile.write(buffer, rate=SAMPLE_RATE, data=audio_data)
        buffer.seek(0)  # rewind so the response streams from the first byte

        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": f"attachment; filename=tts_{uuid.uuid4().hex}.wav"},
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}") from e
# Speech-to-Text endpoint
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
def speech_to_text(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the module-level ASR model.

    The upload is decoded with ``soundfile``, downmixed to mono, resampled to
    16 kHz when needed, and passed through ``asr_model``. The highest-scoring
    token at each step is decoded with ``asr_processor``.

    Args:
        audio_file: Uploaded audio in a format ``soundfile`` can read.

    Returns:
        ``{"transcription": <decoded text>}``.

    Raises:
        HTTPException: Status 500 with detail ``"STT failed: ..."`` if any
            step fails.
    """
    target_sr = 16000
    try:
        audio_bytes = audio_file.file.read()
        audio, sample_rate = sf.read(BytesIO(audio_bytes))

        # soundfile returns a 2-D (frames, channels) array for multi-channel
        # files. Resampling and the ASR processor below are fed a single 1-D
        # waveform, so average the channels down to mono first.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Resample if necessary
        if sample_rate != target_sr:
            # Imported lazily, as in the original, so librosa is only needed
            # when an upload actually requires resampling.
            import librosa

            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_sr)
            sample_rate = target_sr

        inputs = asr_processor(
            audio, sampling_rate=sample_rate, return_tensors="pt", padding=True
        ).input_values.to(device)
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            logits = asr_model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]
        return {"transcription": transcription}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}") from e
# Sentiment Analysis endpoint
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
def analyze_sentiment(request: SentimentRequest):
    """Run the module-level sentiment pipeline on ``request.text``.

    Args:
        request: Body carrying the text to classify.

    Returns:
        ``{"result": <pipeline output>}``; the pipeline's output is passed
        through unchanged.

    Raises:
        HTTPException: Status 500 with detail
            ``"Sentiment analysis failed: ..."`` if the pipeline raises.
    """
    try:
        result = sentiment_pipeline(request.text)
        return {"result": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}") from e
# Legal Parsing endpoint
# NOTE(review): no route decorator is visible on this handler in this view —
# confirm it is registered on `app`.
def parse_legal_document(request: LegalDocRequest):
    """Report which of a fixed set of legal keywords occur in ``request.text``.

    Matching is a case-insensitive substring test, so a keyword also matches
    inside longer words (e.g. "party" inside "counterparty").

    Args:
        request: Body carrying the document text and a domain label.

    Returns:
        Dict with ``identified_keywords`` (matches, in the fixed keyword
        order), ``domain`` (echoed from the request) and ``status``
        (always ``"success"``).

    Raises:
        HTTPException: Status 500 with detail ``"Legal parsing failed: ..."``
            if anything raises.
    """
    try:
        keywords = ("contract", "agreement", "party", "terms", "confidential", "jurisdiction")
        # Lower-case the document once instead of once per keyword.
        lowered_text = request.text.lower()
        found = [kw for kw in keywords if kw in lowered_text]
        return {
            "identified_keywords": found,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}") from e