from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BarkModel,
    pipeline,
    AutoModelForSequenceClassification,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import scipy.io.wavfile as wavfile
import uuid
import os
from io import BytesIO
import soundfile as sf
from typing import Optional
# FastAPI instance
app = FastAPI(title="Kinyarwanda Engine", version="1.0")
# Config
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
SAMPLE_RATE = 24000
ASR_MODEL_PATH = "lucio/wav2vec2-large-xlsr-kinyarwanda"
# Ensure working directory for audio
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)
# Load models
try:
    # Device config
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TTS
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    model.to(device)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512
    )

    # STT
    asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
    asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
    asr_model.to(device)
except Exception as e:
    raise RuntimeError(f"Model initialization failed: {e}")
# Request schemas
class TTSRequest(BaseModel):
    text: str

class SentimentRequest(BaseModel):
    text: str

class LegalDocRequest(BaseModel):
    text: str
    domain: Optional[str] = "general"
# Root route
@app.get("/")
def root():
    return {"message": "Welcome to Kinyarwanda Engine"}
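
# Example health check (illustrative only, not part of the app; assumes the API
# is served at http://localhost:8000 — adjust host/port to your deployment):
#
#   import requests
#   print(requests.get("http://localhost:8000/").json())
#   # -> {"message": "Welcome to Kinyarwanda Engine"}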
# Text-to-Speech endpoint
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        audio_data = audio_array.cpu().numpy().squeeze()
        buffer = BytesIO()
        wavfile.write(buffer, rate=SAMPLE_RATE, data=audio_data)
        buffer.seek(0)
        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": f"attachment; filename=tts_{uuid.uuid4().hex}.wav"}
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
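
# Example TTS client call (illustrative sketch, not part of the app; assumes the
# API runs at http://localhost:8000 and uses "Muraho!" as sample Kinyarwanda text):
#
#   import requests
#   resp = requests.post("http://localhost:8000/tts/", json={"text": "Muraho!"})
#   with open("tts_output.wav", "wb") as f:
#       f.write(resp.content)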
# Speech-to-Text endpoint
@app.post("/stt/")
def speech_to_text(audio_file: UploadFile = File(...)):
    try:
        audio_bytes = audio_file.file.read()
        audio, sample_rate = sf.read(BytesIO(audio_bytes))
        # Downmix multi-channel audio to mono; Wav2Vec2 expects a 1-D waveform
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        # Resample if necessary
        if sample_rate != 16000:
            import librosa
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        inputs = asr_processor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True).input_values.to(device)
        with torch.no_grad():
            logits = asr_model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]
        return {"transcription": transcription}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}")
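
# Example STT client call (illustrative sketch, not part of the app; assumes the
# API runs at http://localhost:8000 and that sample.wav is a local speech recording):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:8000/stt/",
#           files={"audio_file": ("sample.wav", f, "audio/wav")},
#       )
#   print(resp.json())  # -> {"transcription": "..."}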
# Sentiment Analysis endpoint
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    try:
        result = sentiment_pipeline(request.text)
        return {"result": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}")
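
# Example sentiment client call (illustrative sketch, not part of the app; assumes
# the API runs at http://localhost:8000 and uses "Murakoze cyane!" as sample text):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/sentiment/",
#       json={"text": "Murakoze cyane!"},
#   )
#   print(resp.json())  # -> {"result": [{"label": "...", "score": ...}]}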
# Legal Parsing endpoint
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    try:
        keywords = ["contract", "agreement", "party", "terms", "confidential", "jurisdiction"]
        found = [kw for kw in keywords if kw in request.text.lower()]
        return {
            "identified_keywords": found,
            "domain": request.domain,
            "status": "success"
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}")
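
# Example legal-parse client call (illustrative sketch, not part of the app; assumes
# the API runs at http://localhost:8000; note the endpoint only matches the English
# keywords listed above):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/legal-parse/",
#       json={"text": "This agreement binds each party to confidential terms.",
#             "domain": "contracts"},
#   )
#   print(resp.json())
#   # -> {"identified_keywords": ["agreement", "party", "terms", "confidential"], ...}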