Commit 4c65a2b
Parent(s): c796506
test

Files changed:
- app.py +43 -136
- not.txt +10 -0
- requirements.txt +1 -0
app.py
CHANGED
@@ -3,144 +3,35 @@ import os
 import time
 from pathlib import Path
 
+import anthropic
 import gradio as gr
 import numpy as np
+from dotenv import load_dotenv
+from elevenlabs import ElevenLabs
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse, StreamingResponse
 from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
+    get_tts_model,
     get_twilio_turn_credentials,
 )
 from fastrtc.utils import audio_to_bytes
 from gradio.utils import get_space
+from groq import Groq
 from pydantic import BaseModel
-import torch
-
-# Import for open-source models
-from transformers import (
-    AutoModelForSpeechSeq2Seq,
-    AutoProcessor,
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    pipeline
-)
-import gc
 
-
+load_dotenv()
+
+groq_client = Groq()
+claude_client = anthropic.Anthropic()
+tts_client = ElevenLabs(api_key=os.environ["ELEVENLABS_API_KEY"])
+
 curr_dir = Path(__file__).parent
 
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-print("Loading ASR model...")
-asr_model_id = "openai/whisper-small"
-asr_processor = AutoProcessor.from_pretrained(asr_model_id)
-asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
-    asr_model_id,
-    torch_dtype=torch_dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True
-)
-asr_model.to(device)
-asr_pipe = pipeline(
-    "automatic-speech-recognition",
-    model=asr_model,
-    tokenizer=asr_processor.tokenizer,
-    feature_extractor=asr_processor.feature_extractor,
-    max_new_tokens=128,
-    chunk_length_s=30,
-    batch_size=16,
-    return_timestamps=False,
-    device=device,
-)
-
-# Load LLM (TinyLlama for lightweight operation)
-print("Loading LLM model...")
-llm_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
-llm_model = AutoModelForCausalLM.from_pretrained(
-    llm_model_id,
-    torch_dtype=torch_dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True
-)
-llm_model.to(device)
-
-# Load TTS model (Piper TTS or CoquiTTS)
-print("Loading TTS model...")
-from TTS.api import TTS
-tts_model = TTS("tts_models/en/ljspeech/tacotron2-DDC", gpu=torch.cuda.is_available())
-
-# Free up memory after loading models
-gc.collect()
-torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
-def transcribe_audio(audio):
-    """Transcribe audio using Whisper model"""
-    result = asr_pipe({"array": audio[1], "sampling_rate": audio[0]})
-    return result["text"]
-
-def generate_llm_response(messages):
-    """Generate response using TinyLlama model"""
-    # Format messages for TinyLlama
-    prompt = ""
-    for msg in messages:
-        if msg["role"] == "user":
-            prompt += f"<|user|>\n{msg['content']}\n<|assistant|>\n"
-        elif msg["role"] == "assistant":
-            prompt += f"{msg['content']}\n"
-
-    # Add final assistant token if not present
-    if not prompt.endswith("<|assistant|>\n"):
-        prompt += "<|assistant|>\n"
-
-    # Generate response
-    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
-    outputs = llm_model.generate(
-        inputs.input_ids,
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-    )
-    response = llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-    return response
-
-def generate_speech(text):
-    """Generate speech from text using TTS model"""
-    # Create a temporary file path
-    output_path = curr_dir / "temp_audio.wav"
-
-    # Generate audio
-    tts_model.tts_to_file(text=text, file_path=str(output_path))
-
-    # Read audio file
-    import wave
-    import numpy as np
-
-    with wave.open(str(output_path), 'rb') as wav_file:
-        # Get audio parameters
-        sample_rate = wav_file.getframerate()
-        n_frames = wav_file.getnframes()
-        n_channels = wav_file.getnchannels()
-
-        # Read audio data
-        data = wav_file.readframes(n_frames)
-
-        # Convert to numpy array
-        audio_data = np.frombuffer(data, dtype=np.int16)
-        if n_channels == 2:  # Convert stereo to mono
-            audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
-
-    # Clean up temp file
-    if output_path.exists():
-        os.remove(output_path)
-
-    # Return audio data and sample rate
-    return (sample_rate, audio_data)
+tts_model = get_tts_model()
+
 
 def response(
     audio: tuple[int, np.ndarray],
@@ -148,28 +39,36 @@ def response(
 ):
     chatbot = chatbot or []
     messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
-
-
-
+    prompt = groq_client.audio.transcriptions.create(
+        file=("audio-file.mp3", audio_to_bytes(audio)),
+        model="whisper-large-v3-turbo",
+        response_format="verbose_json",
+    ).text
     chatbot.append({"role": "user", "content": prompt})
     yield AdditionalOutputs(chatbot)
     messages.append({"role": "user", "content": prompt})
-
-
-
+    response = claude_client.messages.create(
+        model="claude-3-5-haiku-20241022",
+        max_tokens=512,
+        messages=messages,  # type: ignore
+    )
+    response_text = " ".join(
+        block.text  # type: ignore
+        for block in response.content
+        if getattr(block, "type", None) == "text"
+    )
     chatbot.append({"role": "assistant", "content": response_text})
-
-    # Generate speech
+
     start = time.time()
+
     print("starting tts", start)
-
-
-
-
-
-
+    for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
+        print("chunk", i, time.time() - start)
+        yield chunk
+    print("finished tts", time.time() - start)
+    yield AdditionalOutputs(chatbot)
+
 
-# Set up Gradio chatbot interface
 chatbot = gr.Chatbot(type="messages")
 stream = Stream(
     modality="audio",
@@ -183,17 +82,21 @@ stream = Stream(
     time_limit=90 if get_space() else None,
 )
 
+
 class Message(BaseModel):
     role: str
     content: str
 
+
 class InputData(BaseModel):
     webrtc_id: str
     chatbot: list[Message]
 
+
 app = FastAPI()
 stream.mount(app)
 
+
 @app.get("/")
 async def _():
     rtc_config = get_twilio_turn_credentials() if get_space() else None
@@ -201,11 +104,13 @@ async def _():
     html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
     return HTMLResponse(content=html_content, status_code=200)
 
+
 @app.post("/input_hook")
 async def _(body: InputData):
     stream.set_input(body.webrtc_id, body.model_dump()["chatbot"])
     return {"status": "ok"}
 
+
 @app.get("/outputs")
 def _(webrtc_id: str):
     async def output_stream():
@@ -215,6 +120,7 @@ def _(webrtc_id: str):
 
     return StreamingResponse(output_stream(), media_type="text/event-stream")
 
+
 if __name__ == "__main__":
     import os
 
@@ -224,4 +130,5 @@ if __name__ == "__main__":
         stream.fastphone(host="0.0.0.0", port=7860)
     else:
         import uvicorn
+
         uvicorn.run(app, host="0.0.0.0", port=7860)
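For orientation, a minimal client sketch of the two HTTP hooks this file exposes, separate from the WebRTC audio path. This is hypothetical usage, not part of the commit: it assumes the app is running locally on port 7860 and uses requests (not listed in requirements.txt), and the exact payload of the server-sent events from /outputs is not shown in this diff, so the listener below just prints raw lines.

import requests

BASE = "http://localhost:7860"
webrtc_id = "example-peer-id"  # placeholder; in practice this comes from the WebRTC handshake

# Seed the chat history for this peer via the input hook.
payload = {"webrtc_id": webrtc_id, "chatbot": [{"role": "user", "content": "Hello!"}]}
print(requests.post(f"{BASE}/input_hook", json=payload).json())  # expected: {"status": "ok"}

# Listen for chatbot updates streamed as server-sent events.
with requests.get(f"{BASE}/outputs", params={"webrtc_id": webrtc_id}, stream=True) as resp:
    for line in resp.iter_lines():
        if line:
            print(line.decode())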
not.txt
ADDED
@@ -0,0 +1,10 @@
+fastrtc[vad]
+gradio>=4.0.0
+transformers>=4.37.0
+torch>=2.0.0
+numpy>=1.24.0
+fastapi>=0.103.1
+uvicorn>=0.23.2
+TTS>=0.17.0
+pydantic>=2.0.0
+soundfile>=0.12.1
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ groq
 anthropic
 twilio
 python-dotenv
+torch>=2.0.0
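The Groq and Anthropic clients in app.py are constructed without explicit keys, so the Space depends on API keys being available in the environment (picked up via load_dotenv() and python-dotenv). A small startup-check sketch follows; GROQ_API_KEY and ANTHROPIC_API_KEY are assumed to be the SDKs' default variable names, while ELEVENLABS_API_KEY is read explicitly in app.py.

import os
from dotenv import load_dotenv

load_dotenv()  # pick up a local .env file, as app.py does
for key in ("GROQ_API_KEY", "ANTHROPIC_API_KEY", "ELEVENLABS_API_KEY"):
    if not os.environ.get(key):
        raise RuntimeError(f"{key} is not set; add it to the Space secrets or a local .env file")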