# app.py — Kinyarwanda Engine (FastAPI service: TTS, STT, sentiment, legal parsing)
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
from transformers import (
AutoTokenizer,
AutoProcessor,
BarkModel,
pipeline,
AutoModelForSequenceClassification,
Wav2Vec2Processor,
Wav2Vec2ForCTC
)
import scipy.io.wavfile as wavfile
import uuid
import os
from io import BytesIO
import soundfile as sf
from typing import Optional
# FastAPI instance
app = FastAPI(title="Kinyarwanda Engine", version="1.0")
# Config
# Local paths for models baked into the container image.
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
# Output sample rate used when writing TTS WAV responses
# (assumed to match Bark's native output rate — TODO confirm).
SAMPLE_RATE = 24000
# ASR model is pulled from the Hugging Face Hub by repo id.
ASR_MODEL_PATH = "lucio/wav2vec2-large-xlsr-kinyarwanda"
# Ensure working directory for audio
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
# Load all models once at import time; fail fast so the service never starts
# in a half-initialized state.
try:
    # Single device decision shared by every model below (previously computed
    # twice, once for ASR and once for TTS).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TTS (Bark)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    model.to(device)

    # Sentiment classifier, wrapped in a pipeline so over-length inputs are
    # truncated to the model's 512-token limit instead of erroring.
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # STT (wav2vec2 fine-tuned for Kinyarwanda)
    asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
    asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
    asr_model.to(device)
except Exception as e:
    # Chain the original exception so the startup traceback shows the root cause.
    raise RuntimeError(f"Model initialization failed: {e}") from e
# Request schemas
# Request body for POST /tts/.
class TTSRequest(BaseModel):
    # Text to synthesize into speech.
    text: str
# Request body for POST /sentiment/.
class SentimentRequest(BaseModel):
    # Text to classify.
    text: str
# Request body for POST /legal-parse/.
class LegalDocRequest(BaseModel):
    # Document text to scan for legal keywords.
    text: str
    # Optional domain tag, echoed back in the response.
    domain: Optional[str] = "general"
# Root route
@app.get("/")
def root():
    """Landing endpoint confirming the service is running."""
    greeting = {"message": "Welcome to Kinyarwanda Engine"}
    return greeting
# Text-to-Speech endpoint
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for `request.text` with Bark and stream it as a WAV.

    Returns a streaming `audio/wav` attachment with a random filename.
    Raises 400 for empty input and 500 if generation fails.
    """
    # Reject empty/whitespace-only text up front instead of letting the model fail.
    if not request.text or not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty")
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        # Squeeze the batch dimension to get a 1-D waveform for the WAV writer.
        audio_data = audio_array.cpu().numpy().squeeze()
        buffer = BytesIO()
        # SAMPLE_RATE is assumed to match Bark's output rate — TODO confirm.
        wavfile.write(buffer, rate=SAMPLE_RATE, data=audio_data)
        buffer.seek(0)
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": f"attachment; filename=tts_{uuid.uuid4().hex}.wav"},
        )
    except HTTPException:
        # Bug fix: don't re-wrap deliberate HTTP errors as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
# Speech-to-Text endpoint
@app.post("/stt/")
def speech_to_text(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the Kinyarwanda wav2vec2 model.

    Returns {"transcription": <text>}. Raises 400 for an empty upload and
    500 if decoding/inference fails.
    """
    try:
        audio_bytes = audio_file.file.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
        audio, sample_rate = sf.read(BytesIO(audio_bytes))
        # Bug fix: sf.read returns a 2-D (frames, channels) array for stereo
        # files, which breaks the 1-D mono input wav2vec2 expects. Downmix by
        # averaging channels.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        # wav2vec2 models operate on 16 kHz audio; resample anything else.
        if sample_rate != 16000:
            import librosa  # local import: only needed when resampling
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        inputs = asr_processor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True).input_values.to(device)
        with torch.no_grad():
            logits = asr_model(inputs).logits
        # Greedy CTC decoding: most likely token per frame, then collapse.
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]
        return {"transcription": transcription}
    except HTTPException:
        # Bug fix: let deliberate HTTP errors (e.g. the 400 above) through
        # instead of converting them to 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}")
# Sentiment Analysis endpoint
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Run the preloaded sentiment pipeline over the submitted text."""
    try:
        prediction = sentiment_pipeline(request.text)
        response = {"result": prediction}
        return response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}")
# Legal Parsing endpoint
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Scan the document text for common legal keywords (case-insensitive substring match)."""
    try:
        lowered = request.text.lower()
        keyword_list = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
        matches = []
        for term in keyword_list:
            if term in lowered:
                matches.append(term)
        return {
            "identified_keywords": matches,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}")