File size: 4,595 Bytes
fb42ae8
71579f2
364b411
 
2dacdf9
 
 
 
 
fb42ae8
 
 
2dacdf9
364b411
 
 
fb42ae8
 
364b411
 
80bc30e
92a7105
af59fff
80bc30e
cf0cbad
2dacdf9
80bc30e
71579f2
cf0cbad
80bc30e
 
 
 
 
cf0cbad
80bc30e
cf0cbad
 
 
80bc30e
 
2dacdf9
 
af59fff
cf0cbad
2dacdf9
 
af59fff
 
cf0cbad
80bc30e
fb42ae8
 
 
 
 
80bc30e
af59fff
 
80bc30e
cf0cbad
80bc30e
364b411
80bc30e
364b411
 
 
 
 
 
 
 
 
 
80bc30e
364b411
 
92a7105
364b411
80bc30e
364b411
 
 
cf0cbad
fb42ae8
cf0cbad
80bc30e
 
fb42ae8
 
 
 
80bc30e
fb42ae8
 
af59fff
fb42ae8
af59fff
80bc30e
364b411
80bc30e
 
fb42ae8
 
 
 
 
 
 
71579f2
 
 
 
 
 
fb42ae8
 
 
 
 
 
 
 
 
 
 
364b411
80bc30e
364b411
 
 
af59fff
364b411
 
80bc30e
364b411
80bc30e
364b411
 
 
cf0cbad
80bc30e
cf0cbad
80bc30e
cf0cbad
 
 
364b411
80bc30e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
from transformers import (
    AutoTokenizer,
    AutoProcessor,
    BarkModel,
    pipeline,
    AutoModelForSequenceClassification,
    Wav2Vec2Processor,
    Wav2Vec2ForCTC
)
import scipy.io.wavfile as wavfile
import uuid
import os
from io import BytesIO
import soundfile as sf
from typing import Optional

# FastAPI application; every route in this module is registered on it.
app = FastAPI(version="1.0", title="Kinyarwanda Engine")

# Config -----------------------------------------------------------------
# Model checkpoints. The first two are local directories; the ASR entry has
# the "owner/name" shape of a Hugging Face Hub repo id.
MODEL_PATH = "/app/models/suno-bark"  # Bark text-to-speech checkpoint
SENTIMENT_MODEL_PATH = "/app/models/sentiment"  # sequence-classification checkpoint
ASR_MODEL_PATH = "lucio/wav2vec2-large-xlsr-kinyarwanda"  # wav2vec2 CTC speech recognizer

# Sample rate (Hz) written into the WAV header of /tts/ responses.
SAMPLE_RATE = 24000

# Scratch directory for audio files, created at import time if it is missing.
# NOTE(review): no endpoint in this module writes here (TTS/STT use in-memory
# buffers); confirm whether it is still needed.
AUDIO_DIR = "/tmp/audio"
os.makedirs(name=AUDIO_DIR, exist_ok=True)

# Load models
# Runs at import time: any failure aborts module import with a single
# RuntimeError, so the app never starts with only some models available.
try:
    # Device config: chosen once, before any model is loaded, so the TTS model,
    # the ASR model and the request tensors moved with `.to(device)` in the
    # endpoints always agree on the same device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # TTS
    # NOTE(review): `tokenizer` is not used by the endpoints in this module
    # (they call `processor`); kept in case something else imports it.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    processor = AutoProcessor.from_pretrained(MODEL_PATH)
    model = BarkModel.from_pretrained(MODEL_PATH)
    model.to(device)

    # Sentiment
    sentiment_tokenizer = AutoTokenizer.from_pretrained(SENTIMENT_MODEL_PATH)
    sentiment_model = AutoModelForSequenceClassification.from_pretrained(SENTIMENT_MODEL_PATH)
    # NOTE(review): no `device=` is passed and the model is never moved, so this
    # pipeline presumably runs on CPU even when CUDA is available — confirm.
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=sentiment_model,
        tokenizer=sentiment_tokenizer,
        truncation=True,  # long inputs are cut to max_length tokens, not rejected
        max_length=512,
    )

    # STT
    asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
    asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
    asr_model.to(device)

except Exception as e:
    # Chain the original error so its full traceback survives in startup logs.
    raise RuntimeError(f"Model initialization failed: {e}") from e

# Request schemas (pydantic models FastAPI uses to validate JSON request bodies)
# Body of POST /tts/.
class TTSRequest(BaseModel):
    # Text handed to the Bark processor for synthesis.
    text: str

# Body of POST /sentiment/.
class SentimentRequest(BaseModel):
    # Text passed to the sentiment pipeline (which truncates it to 512 tokens).
    text: str

# Body of POST /legal-parse/.
class LegalDocRequest(BaseModel):
    # Document text scanned for legal keywords.
    text: str
    # Free-form label echoed back unchanged in the response; the endpoint does
    # not otherwise use it. Defaults to "general"; an explicit null is accepted.
    domain: Optional[str] = "general"

# Root route
@app.get("/")
def root():
    """Return a fixed welcome payload; touches no model."""
    welcome = "Welcome to Kinyarwanda Engine"
    return dict(message=welcome)

# Text-to-Speech endpoint
@app.post("/tts/")
def text_to_speech(request: TTSRequest):
    """Synthesize ``request.text`` with Bark and return it as a WAV attachment.

    The waveform is written at ``SAMPLE_RATE`` into an in-memory WAV file and
    streamed back with a unique ``tts_<hex>.wav`` download filename.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only;
            500 if tokenization, generation or WAV encoding fails.
    """
    # Blank input is a client error. Checking it before the `try` keeps the
    # broad handler below from rewriting this 400 into a 500.
    if not request.text.strip():
        raise HTTPException(status_code=400, detail="Text must not be empty.")

    try:
        inputs = processor(request.text, return_tensors="pt").to(device)

        with torch.no_grad():
            audio_array = model.generate(**inputs)

        # squeeze() drops size-1 axes (presumably the batch axis for this
        # single prompt) so a 1-D waveform is written.
        audio_data = audio_array.cpu().numpy().squeeze()
        buffer = BytesIO()
        wavfile.write(buffer, rate=SAMPLE_RATE, data=audio_data)
        buffer.seek(0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}") from e

    return StreamingResponse(
        buffer,
        media_type="audio/wav",
        headers={"Content-Disposition": f"attachment; filename=tts_{uuid.uuid4().hex}.wav"},
    )

# Speech-to-Text endpoint
@app.post("/stt/")
def speech_to_text(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file with the wav2vec2 CTC model.

    The upload is decoded with soundfile, down-mixed to mono, resampled to
    16 kHz when needed, and decoded greedily (argmax over the CTC logits).

    Returns:
        ``{"transcription": str}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as audio or holds
            no samples; 500 if resampling or inference fails.
    """
    # Rate the ASR model is fed; presumably equals
    # asr_processor.feature_extractor.sampling_rate — confirm if the model changes.
    target_rate = 16000

    try:
        audio_bytes = audio_file.file.read()
        try:
            audio, sample_rate = sf.read(BytesIO(audio_bytes))
        except RuntimeError as e:
            # soundfile reports unreadable/unsupported input via RuntimeError
            # (and its subclasses): that is a bad upload, not a server fault.
            raise HTTPException(status_code=400, detail=f"Could not decode audio file: {e}") from e

        # sf.read returns a (frames, channels) array for multi-channel files.
        # Average to mono: the model expects one 1-D waveform, and
        # librosa.resample works along the last axis, which would otherwise be
        # the channel axis rather than time.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        if audio.size == 0:
            raise HTTPException(status_code=400, detail="Audio file contains no samples.")

        # Resample if necessary
        if sample_rate != target_rate:
            import librosa  # optional dependency, only needed for non-16 kHz input
            audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_rate)
            sample_rate = target_rate

        inputs = asr_processor(
            audio, sampling_rate=sample_rate, return_tensors="pt", padding=True
        ).input_values.to(device)

        with torch.no_grad():
            logits = asr_model(inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)

        transcription = asr_processor.batch_decode(predicted_ids)[0]
    except HTTPException:
        # Client errors raised above must not be rewritten into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}") from e

    return {"transcription": transcription}

# Sentiment Analysis endpoint
@app.post("/sentiment/")
def analyze_sentiment(request: SentimentRequest):
    """Classify ``request.text`` with the sentiment pipeline.

    The pipeline's output is returned unchanged under the ``"result"`` key;
    any exception raised by the pipeline becomes an HTTP 500.
    """
    try:
        predictions = sentiment_pipeline(request.text)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Sentiment analysis failed: {str(e)}")
    return {"result": predictions}

# Legal Parsing endpoint
@app.post("/legal-parse/")
def parse_legal_document(request: LegalDocRequest):
    """Report which of a fixed set of legal keywords occur in ``request.text``.

    Matching is a case-insensitive plain substring test, so "contract" also
    matches inside "subcontractor". Hits are listed in vocabulary order, and
    ``request.domain`` is echoed back untouched.
    """
    try:
        lowered = request.text.lower()
        vocabulary = ("contract", "agreement", "party", "terms", "confidential", "jurisdiction")
        hits = [term for term in vocabulary if term in lowered]
        payload = {
            "identified_keywords": hits,
            "domain": request.domain,
            "status": "success",
        }
        return payload
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Legal parsing failed: {str(e)}")