Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException
|
2 |
-
from fastapi.responses import
|
3 |
from pydantic import BaseModel
|
4 |
import torch
|
5 |
from transformers import (
|
@@ -7,11 +7,15 @@ from transformers import (
|
|
7 |
AutoProcessor,
|
8 |
BarkModel,
|
9 |
pipeline,
|
10 |
-
AutoModelForSequenceClassification
|
|
|
|
|
11 |
)
|
12 |
import scipy.io.wavfile as wavfile
|
13 |
import uuid
|
14 |
import os
|
|
|
|
|
15 |
from typing import Optional
|
16 |
|
17 |
# FastAPI instance
|
@@ -21,6 +25,7 @@ app = FastAPI(title="Kinyarwanda Engine", version="1.0")
|
|
21 |
MODEL_PATH = "/app/models/suno-bark"
|
22 |
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
|
23 |
SAMPLE_RATE = 24000
|
|
|
24 |
|
25 |
# Ensure working directory for audio
|
26 |
AUDIO_DIR = "/tmp/audio"
|
@@ -44,6 +49,11 @@ try:
|
|
44 |
max_length=512
|
45 |
)
|
46 |
|
|
|
|
|
|
|
|
|
|
|
47 |
# Device config
|
48 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
49 |
model.to(device)
|
@@ -70,27 +80,44 @@ def root():
|
|
70 |
# Text-to-Speech endpoint
|
71 |
@app.post("/tts/")
|
72 |
def text_to_speech(request: TTSRequest):
|
73 |
-
output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")
|
74 |
-
|
75 |
try:
|
76 |
inputs = processor(request.text, return_tensors="pt").to(device)
|
|
|
77 |
with torch.no_grad():
|
78 |
audio_array = model.generate(**inputs)
|
79 |
|
80 |
-
|
|
|
|
|
|
|
81 |
|
82 |
-
return
|
83 |
-
|
84 |
media_type="audio/wav",
|
85 |
-
filename=
|
86 |
)
|
87 |
|
88 |
except Exception as e:
|
89 |
raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
|
90 |
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
# Sentiment Analysis endpoint
|
96 |
@app.post("/sentiment/")
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, UploadFile, File
|
2 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
from pydantic import BaseModel
|
4 |
import torch
|
5 |
from transformers import (
|
|
|
7 |
AutoProcessor,
|
8 |
BarkModel,
|
9 |
pipeline,
|
10 |
+
AutoModelForSequenceClassification,
|
11 |
+
Wav2Vec2Processor,
|
12 |
+
Wav2Vec2ForCTC
|
13 |
)
|
14 |
import scipy.io.wavfile as wavfile
|
15 |
import uuid
|
16 |
import os
|
17 |
+
from io import BytesIO
|
18 |
+
import soundfile as sf
|
19 |
from typing import Optional
|
20 |
|
21 |
# FastAPI instance
|
|
|
25 |
MODEL_PATH = "/app/models/suno-bark"
|
26 |
SENTIMENT_MODEL_PATH = "/app/models/sentiment"
|
27 |
SAMPLE_RATE = 24000
|
28 |
+
ASR_MODEL_PATH = "jonatasgrosman/wav2vec2-large-xlsr-53-Kinyarwanda"
|
29 |
|
30 |
# Ensure working directory for audio
|
31 |
AUDIO_DIR = "/tmp/audio"
|
|
|
49 |
max_length=512
|
50 |
)
|
51 |
|
52 |
+
# STT
|
53 |
+
asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
|
54 |
+
asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
|
55 |
+
asr_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
56 |
+
|
57 |
# Device config
|
58 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
59 |
model.to(device)
|
|
|
80 |
# Text-to-Speech endpoint
|
81 |
@app.post("/tts/")
|
82 |
def text_to_speech(request: TTSRequest):
|
|
|
|
|
83 |
try:
|
84 |
inputs = processor(request.text, return_tensors="pt").to(device)
|
85 |
+
|
86 |
with torch.no_grad():
|
87 |
audio_array = model.generate(**inputs)
|
88 |
|
89 |
+
audio_data = audio_array.cpu().numpy().squeeze()
|
90 |
+
buffer = BytesIO()
|
91 |
+
wavfile.write(buffer, rate=SAMPLE_RATE, data=audio_data)
|
92 |
+
buffer.seek(0)
|
93 |
|
94 |
+
return StreamingResponse(
|
95 |
+
buffer,
|
96 |
media_type="audio/wav",
|
97 |
+
headers={"Content-Disposition": f"attachment; filename=tts_{uuid.uuid4().hex}.wav"}
|
98 |
)
|
99 |
|
100 |
except Exception as e:
|
101 |
raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")
|
102 |
|
103 |
+
# Speech-to-Text endpoint
|
104 |
+
@app.post("/stt/")
|
105 |
+
def speech_to_text(audio_file: UploadFile = File(...)):
|
106 |
+
try:
|
107 |
+
audio_bytes = audio_file.file.read()
|
108 |
+
audio, sample_rate = sf.read(BytesIO(audio_bytes))
|
109 |
+
|
110 |
+
inputs = asr_processor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True).input_values.to(device)
|
111 |
+
|
112 |
+
with torch.no_grad():
|
113 |
+
logits = asr_model(inputs).logits
|
114 |
+
predicted_ids = torch.argmax(logits, dim=-1)
|
115 |
+
|
116 |
+
transcription = asr_processor.batch_decode(predicted_ids)[0]
|
117 |
+
return {"transcription": transcription}
|
118 |
+
|
119 |
+
except Exception as e:
|
120 |
+
raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}")
|
121 |
|
122 |
# Sentiment Analysis endpoint
|
123 |
@app.post("/sentiment/")
|