"""Kinyarwanda Engine: a FastAPI service for speech and text processing.

Endpoints:
    GET  /              -- welcome message.
    POST /tts/          -- text-to-speech with a Bark model, returned as WAV.
    POST /stt/          -- speech-to-text with a Wav2Vec2 CTC model.
    POST /sentiment/    -- sentiment classification of a text.
    POST /legal-parse/  -- keyword spotting for legal vocabulary.

All models are loaded once at import time; a failure to load any of them
aborts start-up with a ``RuntimeError``.
"""

import os
import uuid
from io import BytesIO
from typing import Optional

import scipy.io.wavfile as wavfile
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BarkModel,
    pipeline,
    AutoModelForSequenceClassification,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
)

# FastAPI instance
app = FastAPI(title="Kinyarwanda Engine", version="1.0")

# Config
MODEL_PATH = "/app/models/suno-bark"
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
SAMPLE_RATE = 24000  # sample rate written into the TTS WAV header
ASR_MODEL_PATH = "lucio/wav2vec2-large-xlsr-kinyarwanda"
ASR_SAMPLE_RATE = 16000  # uploads are resampled to this rate before ASR

# Substrings searched for (case-insensitively) by the legal parsing endpoint.
_LEGAL_KEYWORDS = (
    "contract",
    "agreement",
    "party",
    "terms",
    "confidential",
    "jurisdiction",
)

# Ensure working directory for audio
AUDIO_DIR = "/tmp/audio"
os.makedirs(AUDIO_DIR, exist_ok=True)

# Load models
try:
    # Device config: resolved once, up front, so every model and every
    # request tensor is placed on the same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TTS
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    model.to(device)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(
        SENTIMENT_MODEL_PATH
    )
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,
        max_length=512,
    )

    # STT
    asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
    asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
    asr_model.to(device)
except Exception as e:
    # Chain the cause so the original traceback is kept in the start-up log.
    raise RuntimeError(f"Model initialization failed: {e}") from e


# Request schemas
class TTSRequest(BaseModel):
    """Request body for ``POST /tts/``."""

    text: str


class SentimentRequest(BaseModel):
    """Request body for ``POST /sentiment/``."""

    text: str


class LegalDocRequest(BaseModel):
    """Request body for ``POST /legal-parse/``."""

    text: str
    domain: Optional[str] = "general"


# NOTE: the endpoint functions deliberately carry no return annotations;
# FastAPI would interpret them as response models.


# Root route
@app.get("/")
def root():
    """Return a static welcome message."""
    return {"message": "Welcome to Kinyarwanda Engine"}


# Text-to-Speech endpoint
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize speech for ``request.text`` and return it as a WAV download.

    Raises:
        HTTPException: 500 if processing or generation fails.
    """
    try:
        inputs = processor(request.text, return_tensors="pt").to(device)
        with torch.no_grad():
            audio_array = model.generate(**inputs)
        audio_data = audio_array.cpu().numpy().squeeze()

        # Encode the waveform in memory; nothing is written to disk.
        buffer = BytesIO()
        wavfile.write(buffer, rate=SAMPLE_RATE, data=audio_data)
        buffer.seek(0)

        return StreamingResponse(
            buffer,
            media_type="audio/wav",
            headers={
                "Content-Disposition": f"attachment; filename=tts_{uuid.uuid4().hex}.wav"
            },
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"TTS generation failed: {str(e)}"
        ) from e


# Speech-to-Text endpoint
@app.post("/stt/")
def speech_to_text(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    The audio is decoded with ``soundfile``, downmixed to mono when it has
    several channels, resampled to ``ASR_SAMPLE_RATE`` when needed, and
    decoded greedily (arg-max over the CTC logits).

    Returns:
        ``{"transcription": <str>}``

    Raises:
        HTTPException: 400 if the upload is empty or cannot be decoded as
            audio; 500 for any other failure.
    """
    try:
        audio_bytes = audio_file.file.read()
        if not audio_bytes:
            raise HTTPException(
                status_code=400, detail="STT failed: uploaded audio file is empty"
            )

        try:
            audio, sample_rate = sf.read(BytesIO(audio_bytes), dtype="float32")
        except RuntimeError as e:
            # soundfile reports unreadable/unsupported data with RuntimeError
            # (or a subclass); that is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=f"STT failed: {str(e)}") from e

        # soundfile returns (frames, channels) for multi-channel input. The
        # resampler works along the last axis and the model expects a single
        # waveform, so downmix to mono before anything else.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        # Resample if necessary
        if sample_rate != ASR_SAMPLE_RATE:
            import librosa  # imported lazily: only needed on this path

            audio = librosa.resample(
                audio, orig_sr=sample_rate, target_sr=ASR_SAMPLE_RATE
            )
            sample_rate = ASR_SAMPLE_RATE

        inputs = asr_processor(
            audio, sampling_rate=sample_rate, return_tensors="pt", padding=True
        ).input_values.to(device)
        with torch.no_grad():
            logits = asr_model(inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = asr_processor.batch_decode(predicted_ids)[0]

        return {"transcription": transcription}
    except HTTPException:
        # Keep deliberate 4xx responses from being rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}") from e


# Sentiment Analysis endpoint
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Run the sentiment pipeline on ``request.text``.

    Returns:
        ``{"result": <pipeline output>}``

    Raises:
        HTTPException: 500 if the pipeline fails.
    """
    try:
        result = sentiment_pipeline(request.text)
        return {"result": result}
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Sentiment analysis failed: {str(e)}"
        ) from e


# Legal Parsing endpoint
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which legal keywords occur in ``request.text``.

    Matching is a case-insensitive substring test, so a keyword is also
    reported when it appears inside a longer word.

    Returns:
        A dict with ``identified_keywords`` (in keyword-list order), the
        echoed ``domain`` and ``status`` set to ``"success"``.

    Raises:
        HTTPException: 500 if parsing fails.
    """
    try:
        lowered = request.text.lower()  # lower-case once, not once per keyword
        found = [kw for kw in _LEGAL_KEYWORDS if kw in lowered]
        return {
            "identified_keywords": found,
            "domain": request.domain,
            "status": "success",
        }
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Legal parsing failed: {str(e)}"
        ) from e