BryanBradfo committed
Commit 4c65a2b · 1 Parent(s): c796506
Files changed (3)
  1. app.py +43 -136
  2. not.txt +10 -0
  3. requirements.txt +1 -0
app.py CHANGED
@@ -3,144 +3,35 @@ import os
 import time
 from pathlib import Path
 
+import anthropic
 import gradio as gr
 import numpy as np
+from dotenv import load_dotenv
+from elevenlabs import ElevenLabs
 from fastapi import FastAPI
 from fastapi.responses import HTMLResponse, StreamingResponse
 from fastrtc import (
     AdditionalOutputs,
     ReplyOnPause,
     Stream,
+    get_tts_model,
     get_twilio_turn_credentials,
 )
 from fastrtc.utils import audio_to_bytes
 from gradio.utils import get_space
+from groq import Groq
 from pydantic import BaseModel
-import torch
-
-# Import for open-source models
-from transformers import (
-    AutoModelForSpeechSeq2Seq,
-    AutoProcessor,
-    AutoTokenizer,
-    AutoModelForCausalLM,
-    pipeline
-)
-import gc
 
-# Current directory
+load_dotenv()
+
+groq_client = Groq()
+claude_client = anthropic.Anthropic()
+tts_client = ElevenLabs(api_key=os.environ["ELEVENLABS_API_KEY"])
+
 curr_dir = Path(__file__).parent
 
-# Load ASR model (Whisper small for lightweight usage)
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-print("Loading ASR model...")
-asr_model_id = "openai/whisper-small"
-asr_processor = AutoProcessor.from_pretrained(asr_model_id)
-asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
-    asr_model_id,
-    torch_dtype=torch_dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True
-)
-asr_model.to(device)
-asr_pipe = pipeline(
-    "automatic-speech-recognition",
-    model=asr_model,
-    tokenizer=asr_processor.tokenizer,
-    feature_extractor=asr_processor.feature_extractor,
-    max_new_tokens=128,
-    chunk_length_s=30,
-    batch_size=16,
-    return_timestamps=False,
-    device=device,
-)
-
-# Load LLM (TinyLlama for lightweight operation)
-print("Loading LLM model...")
-llm_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
-llm_model = AutoModelForCausalLM.from_pretrained(
-    llm_model_id,
-    torch_dtype=torch_dtype,
-    low_cpu_mem_usage=True,
-    use_safetensors=True
-)
-llm_model.to(device)
-
-# Load TTS model (Piper TTS or CoquiTTS)
-print("Loading TTS model...")
-from TTS.api import TTS
-tts_model = TTS("tts_models/en/ljspeech/tacotron2-DDC", gpu=torch.cuda.is_available())
-
-# Free up memory after loading models
-gc.collect()
-torch.cuda.empty_cache() if torch.cuda.is_available() else None
-
-def transcribe_audio(audio):
-    """Transcribe audio using Whisper model"""
-    result = asr_pipe({"array": audio[1], "sampling_rate": audio[0]})
-    return result["text"]
-
-def generate_llm_response(messages):
-    """Generate response using TinyLlama model"""
-    # Format messages for TinyLlama
-    prompt = ""
-    for msg in messages:
-        if msg["role"] == "user":
-            prompt += f"<|user|>\n{msg['content']}\n<|assistant|>\n"
-        elif msg["role"] == "assistant":
-            prompt += f"{msg['content']}\n"
-
-    # Add final assistant token if not present
-    if not prompt.endswith("<|assistant|>\n"):
-        prompt += "<|assistant|>\n"
-
-    # Generate response
-    inputs = llm_tokenizer(prompt, return_tensors="pt").to(device)
-    outputs = llm_model.generate(
-        inputs.input_ids,
-        max_new_tokens=512,
-        do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-    )
-    response = llm_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
-    return response
-
-def generate_speech(text):
-    """Generate speech from text using TTS model"""
-    # Create a temporary file path
-    output_path = curr_dir / "temp_audio.wav"
-
-    # Generate audio
-    tts_model.tts_to_file(text=text, file_path=str(output_path))
-
-    # Read audio file
-    import wave
-    import numpy as np
-
-    with wave.open(str(output_path), 'rb') as wav_file:
-        # Get audio parameters
-        sample_rate = wav_file.getframerate()
-        n_frames = wav_file.getnframes()
-        n_channels = wav_file.getnchannels()
-
-        # Read audio data
-        data = wav_file.readframes(n_frames)
-
-        # Convert to numpy array
-        audio_data = np.frombuffer(data, dtype=np.int16)
-        if n_channels == 2:  # Convert stereo to mono
-            audio_data = audio_data.reshape(-1, 2).mean(axis=1).astype(np.int16)
-
-    # Clean up temp file
-    if output_path.exists():
-        os.remove(output_path)
-
-    # Return audio data and sample rate
-    return (sample_rate, audio_data)
+tts_model = get_tts_model()
 
+
 def response(
     audio: tuple[int, np.ndarray],
@@ -148,28 +39,36 @@ def response(
 ):
     chatbot = chatbot or []
     messages = [{"role": d["role"], "content": d["content"]} for d in chatbot]
-
-    # Transcribe audio
-    prompt = transcribe_audio(audio)
+    prompt = groq_client.audio.transcriptions.create(
+        file=("audio-file.mp3", audio_to_bytes(audio)),
+        model="whisper-large-v3-turbo",
+        response_format="verbose_json",
+    ).text
     chatbot.append({"role": "user", "content": prompt})
     yield AdditionalOutputs(chatbot)
     messages.append({"role": "user", "content": prompt})
-
-    # Generate response
-    response_text = generate_llm_response(messages)
+    response = claude_client.messages.create(
+        model="claude-3-5-haiku-20241022",
+        max_tokens=512,
+        messages=messages,  # type: ignore
+    )
+    response_text = " ".join(
+        block.text  # type: ignore
+        for block in response.content
+        if getattr(block, "type", None) == "text"
+    )
     chatbot.append({"role": "assistant", "content": response_text})
-
-    # Generate speech
+
     start = time.time()
+
     print("starting tts", start)
-
-    # Generate speech in a single call for simplicity
-    audio_output = generate_speech(response_text)
-    print("finished tts", time.time() - start)
-    yield audio_output
-    yield AdditionalOutputs(chatbot)
+    for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
+        print("chunk", i, time.time() - start)
+        yield chunk
+    print("finished tts", time.time() - start)
+    yield AdditionalOutputs(chatbot)
+
 
-# Set up Gradio chatbot interface
 chatbot = gr.Chatbot(type="messages")
 stream = Stream(
     modality="audio",
@@ -183,17 +82,21 @@ stream = Stream(
     time_limit=90 if get_space() else None,
 )
 
+
 class Message(BaseModel):
     role: str
     content: str
 
+
 class InputData(BaseModel):
     webrtc_id: str
     chatbot: list[Message]
 
+
 app = FastAPI()
 stream.mount(app)
 
+
 @app.get("/")
 async def _():
     rtc_config = get_twilio_turn_credentials() if get_space() else None
@@ -201,11 +104,13 @@ async def _():
     html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
     return HTMLResponse(content=html_content, status_code=200)
 
+
 @app.post("/input_hook")
 async def _(body: InputData):
     stream.set_input(body.webrtc_id, body.model_dump()["chatbot"])
     return {"status": "ok"}
 
+
 @app.get("/outputs")
 def _(webrtc_id: str):
     async def output_stream():
@@ -215,6 +120,7 @@ def _(webrtc_id: str):
 
     return StreamingResponse(output_stream(), media_type="text/event-stream")
 
+
 if __name__ == "__main__":
     import os
 
@@ -224,4 +130,5 @@ if __name__ == "__main__":
         stream.fastphone(host="0.0.0.0", port=7860)
     else:
         import uvicorn
+
         uvicorn.run(app, host="0.0.0.0", port=7860)
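The rewritten response handler swaps the local Whisper/TinyLlama/Coqui stack for hosted services: Groq's whisper-large-v3-turbo transcribes the caller's audio, Claude 3.5 Haiku writes the reply, and fastrtc's bundled TTS model streams the synthesized audio back over WebRTC. (Note that the ElevenLabs client is instantiated but never used in the new handler; synthesis goes through get_tts_model().) Below is a minimal sketch of the same three stages outside the WebRTC stream, useful for smoke-testing API keys; the standalone structure and the sample.wav file are illustrative assumptions, not part of the commit.

# Standalone smoke test of the commit's STT -> LLM -> TTS pipeline.
# Assumes GROQ_API_KEY and ANTHROPIC_API_KEY are set in the environment
# and that a short sample.wav recording exists (both are assumptions).
import anthropic
from fastrtc import get_tts_model
from groq import Groq

groq_client = Groq()
claude_client = anthropic.Anthropic()
tts_model = get_tts_model()

# 1) Speech-to-text with Groq's hosted Whisper.
with open("sample.wav", "rb") as f:
    prompt = groq_client.audio.transcriptions.create(
        file=("sample.wav", f.read()),
        model="whisper-large-v3-turbo",
    ).text
print("heard:", prompt)

# 2) Reply with Claude 3.5 Haiku.
reply = claude_client.messages.create(
    model="claude-3-5-haiku-20241022",
    max_tokens=512,
    messages=[{"role": "user", "content": prompt}],
)
response_text = " ".join(
    block.text for block in reply.content if block.type == "text"
)
print("reply:", response_text)

# 3) Stream TTS chunks, as app.py yields them over WebRTC.
for i, chunk in enumerate(tts_model.stream_tts_sync(response_text)):
    print("tts chunk", i)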
not.txt ADDED
@@ -0,0 +1,10 @@
+fastrtc[vad]
+gradio>=4.0.0
+transformers>=4.37.0
+torch>=2.0.0
+numpy>=1.24.0
+fastapi>=0.103.1
+uvicorn>=0.23.2
+TTS>=0.17.0
+pydantic>=2.0.0
+soundfile>=0.12.1
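not.txt appears to archive the dependency pins of the local-model stack this commit removes (transformers, torch, Coqui TTS, soundfile); nothing in the post-commit app reads it. A small sketch to check which of those archived pins are still installed, assuming the file sits in the working directory:

# Report which pins archived in not.txt are still installed.
# The path and this check are illustrative assumptions, not app code.
import re
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path

for line in Path("not.txt").read_text().splitlines():
    req = line.strip()
    if not req or req.startswith("#"):
        continue
    # Strip extras ("[vad]") and version specifiers (">=...").
    name = re.split(r"[\[><=]", req, maxsplit=1)[0]
    try:
        print(f"{req}: installed ({version(name)})")
    except PackageNotFoundError:
        print(f"{req}: not installed")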
requirements.txt CHANGED
@@ -5,3 +5,4 @@ groq
 anthropic
 twilio
 python-dotenv
+torch>=2.0.0
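The requirements change itself is one line: torch>=2.0.0 joins the API clients already listed. The commit doesn't say why torch is pinned here even though the local Transformers models are gone; it may simply be a leftover from the local-model experiment archived in not.txt. A quick import check over the runtime modules; the list is read off app.py's new import block plus this pin, an inferred manifest rather than an official one:

# Sanity-check that the Space's runtime dependencies import cleanly.
# The module list mirrors app.py's post-commit imports plus torch;
# it is inferred from this diff, not from an official manifest.
import importlib

modules = [
    "anthropic", "dotenv", "elevenlabs", "fastapi", "fastrtc",
    "gradio", "groq", "numpy", "pydantic", "torch", "uvicorn",
]
for mod in modules:
    try:
        importlib.import_module(mod)
        print(f"ok: {mod}")
    except ImportError as exc:
        print(f"missing: {mod} ({exc})")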