black44 committed on
Commit fb42ae8 · verified · 1 Parent(s): 92a7105

Update app.py

Files changed (1):
  1. app.py +39 -12
app.py CHANGED
@@ -1,5 +1,5 @@
-from fastapi import FastAPI, HTTPException
-from fastapi.responses import FileResponse, JSONResponse
+from fastapi import FastAPI, HTTPException, UploadFile, File
+from fastapi.responses import StreamingResponse, JSONResponse
 from pydantic import BaseModel
 import torch
 from transformers import (
@@ -7,11 +7,15 @@ from transformers import (
     AutoProcessor,
     BarkModel,
     pipeline,
-    AutoModelForSequenceClassification
+    AutoModelForSequenceClassification,
+    Wav2Vec2Processor,
+    Wav2Vec2ForCTC
 )
 import scipy.io.wavfile as wavfile
 import uuid
 import os
+from io import BytesIO
+import soundfile as sf
 from typing import Optional

 # FastAPI instance
@@ -21,6 +25,7 @@ app = FastAPI(title="Kinyarwanda Engine", version="1.0")
 MODEL_PATH = "/app/models/suno-bark"
 SENTIMENT_MODEL_PATH = "/app/models/sentiment"
 SAMPLE_RATE = 24000
+ASR_MODEL_PATH = "jonatasgrosman/wav2vec2-large-xlsr-53-Kinyarwanda"

 # Ensure working directory for audio
 AUDIO_DIR = "/tmp/audio"
@@ -44,6 +49,11 @@ try:
         max_length=512
     )

+# STT
+asr_processor = Wav2Vec2Processor.from_pretrained(ASR_MODEL_PATH)
+asr_model = Wav2Vec2ForCTC.from_pretrained(ASR_MODEL_PATH)
+asr_model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
+
 # Device config
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
@@ -70,27 +80,44 @@ def root():
 # Text-to-Speech endpoint
 @app.post("/tts/")
 def text_to_speech(request: TTSRequest):
-    output_file = os.path.join(AUDIO_DIR, f"tts_{uuid.uuid4().hex}.wav")
-
     try:
         inputs = processor(request.text, return_tensors="pt").to(device)
+
         with torch.no_grad():
             audio_array = model.generate(**inputs)

-        wavfile.write(output_file, rate=SAMPLE_RATE, data=audio_array.cpu().numpy().squeeze())
+        audio_data = audio_array.cpu().numpy().squeeze()
+        buffer = BytesIO()
+        wavfile.write(buffer, rate=SAMPLE_RATE, data=audio_data)
+        buffer.seek(0)

-        return FileResponse(
-            output_file,
+        return StreamingResponse(
+            buffer,
             media_type="audio/wav",
-            filename=os.path.basename(output_file)
+            headers={"Content-Disposition": f"attachment; filename=tts_{uuid.uuid4().hex}.wav"}
         )

     except Exception as e:
         raise HTTPException(status_code=500, detail=f"TTS generation failed: {str(e)}")

-    finally:
-        if os.path.exists(output_file):
-            os.remove(output_file)
+# Speech-to-Text endpoint
+@app.post("/stt/")
+def speech_to_text(audio_file: UploadFile = File(...)):
+    try:
+        audio_bytes = audio_file.file.read()
+        audio, sample_rate = sf.read(BytesIO(audio_bytes))
+
+        inputs = asr_processor(audio, sampling_rate=sample_rate, return_tensors="pt", padding=True).input_values.to(device)
+
+        with torch.no_grad():
+            logits = asr_model(inputs).logits
+            predicted_ids = torch.argmax(logits, dim=-1)
+
+        transcription = asr_processor.batch_decode(predicted_ids)[0]
+        return {"transcription": transcription}
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"STT failed: {str(e)}")

 # Sentiment Analysis endpoint
 @app.post("/sentiment/")
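
For quick verification of the changed behavior, here is a minimal client sketch that is not part of the commit. It assumes the API is served at http://localhost:8000, that TTSRequest exposes a text field (as request.text in the handler suggests), and it uses the third-party requests package; the base URL, file names, and sample phrase are illustrative only.

# Hypothetical client for the updated /tts/ and /stt/ endpoints (not part of this commit).
import requests

BASE_URL = "http://localhost:8000"  # assumption: local deployment, port not stated in the diff

# /tts/ now streams WAV bytes from an in-memory buffer, so the raw response body
# can be written straight to disk.
tts_resp = requests.post(f"{BASE_URL}/tts/", json={"text": "Muraho neza"})
tts_resp.raise_for_status()
with open("tts_output.wav", "wb") as f:
    f.write(tts_resp.content)

# /stt/ expects a multipart upload whose field name matches the UploadFile
# parameter added in this commit, i.e. "audio_file".
with open("tts_output.wav", "rb") as f:
    stt_resp = requests.post(f"{BASE_URL}/stt/", files={"audio_file": f})
stt_resp.raise_for_status()
print(stt_resp.json()["transcription"])

Returning a StreamingResponse over a BytesIO buffer also sidesteps the old handler's cleanup race, where the finally block would remove the temporary WAV before FileResponse had actually sent it.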