Spaces:
Running
Running
Nirav Madhani
commited on
Commit
·
a1096d7
1
Parent(s):
9b731f8
Front end function call
Browse files- .gitignore +2 -0
- handler.py +25 -12
- index.html +104 -57
- webapp.py +19 -12
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
.env
|
2 |
+
__pycache__/*
|
handler.py
CHANGED
@@ -6,11 +6,18 @@ import os
|
|
6 |
import traceback
|
7 |
from websockets.asyncio.client import connect
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
host = "generativelanguage.googleapis.com"
|
10 |
model = "gemini-2.0-flash-exp"
|
11 |
api_key = os.environ["GOOGLE_API_KEY"]
|
12 |
uri = f"wss://{host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={api_key}"
|
13 |
|
|
|
|
|
14 |
class AudioLoop:
|
15 |
def __init__(self):
|
16 |
self.ws = None
|
@@ -54,33 +61,39 @@ class AudioLoop:
|
|
54 |
response["serverContent"]["modelTurn"]["parts"][0]["inlineData"]["data"]
|
55 |
)
|
56 |
pcm_data = base64.b64decode(b64data)
|
57 |
-
|
|
|
|
|
|
|
|
|
58 |
except KeyError:
|
59 |
# No audio in this message
|
60 |
pass
|
61 |
|
|
|
62 |
tool_call = response.pop('toolCall', None)
|
63 |
if tool_call is not None:
|
64 |
-
await self.
|
|
|
|
|
|
|
65 |
|
66 |
# If "turnComplete" is present
|
67 |
if "serverContent" in response and response["serverContent"].get("turnComplete"):
|
68 |
print("[AudioLoop] Gemini turn complete")
|
69 |
|
70 |
-
async def handle_tool_call(self,
|
71 |
-
|
72 |
-
|
73 |
-
msg = {
|
74 |
'tool_response': {
|
75 |
'function_responses': [{
|
76 |
-
'id':
|
77 |
-
'name':
|
78 |
-
'response':
|
79 |
}]
|
80 |
-
}
|
81 |
}
|
82 |
-
|
83 |
-
|
84 |
|
85 |
async def run(self):
|
86 |
"""Main entry point: connects to Gemini, starts send/receive tasks."""
|
|
|
6 |
import traceback
|
7 |
from websockets.asyncio.client import connect
|
8 |
|
9 |
+
from dotenv import load_dotenv
|
10 |
+
|
11 |
+
# Load environment variables from a .env file
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
host = "generativelanguage.googleapis.com"
|
15 |
model = "gemini-2.0-flash-exp"
|
16 |
api_key = os.environ["GOOGLE_API_KEY"]
|
17 |
uri = f"wss://{host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={api_key}"
|
18 |
|
19 |
+
|
20 |
+
|
21 |
class AudioLoop:
|
22 |
def __init__(self):
|
23 |
self.ws = None
|
|
|
61 |
response["serverContent"]["modelTurn"]["parts"][0]["inlineData"]["data"]
|
62 |
)
|
63 |
pcm_data = base64.b64decode(b64data)
|
64 |
+
# Send audio with type "audio"
|
65 |
+
await self.audio_in_queue.put({
|
66 |
+
"type": "audio",
|
67 |
+
"payload": base64.b64encode(pcm_data).decode()
|
68 |
+
})
|
69 |
except KeyError:
|
70 |
# No audio in this message
|
71 |
pass
|
72 |
|
73 |
+
# Forward function calls to client
|
74 |
tool_call = response.pop('toolCall', None)
|
75 |
if tool_call is not None:
|
76 |
+
await self.audio_in_queue.put({
|
77 |
+
"type": "function_call",
|
78 |
+
"payload": tool_call
|
79 |
+
})
|
80 |
|
81 |
# If "turnComplete" is present
|
82 |
if "serverContent" in response and response["serverContent"].get("turnComplete"):
|
83 |
print("[AudioLoop] Gemini turn complete")
|
84 |
|
85 |
+
async def handle_tool_call(self, tool_call_response):
|
86 |
+
"""Handle tool call response from client"""
|
87 |
+
msg = {
|
|
|
88 |
'tool_response': {
|
89 |
'function_responses': [{
|
90 |
+
'id': tool_call_response['id'],
|
91 |
+
'name': tool_call_response['name'],
|
92 |
+
'response': tool_call_response['response']
|
93 |
}]
|
|
|
94 |
}
|
95 |
+
}
|
96 |
+
await self.ws.send(json.dumps(msg))
|
97 |
|
98 |
async def run(self):
|
99 |
"""Main entry point: connects to Gemini, starts send/receive tasks."""
|
index.html
CHANGED
@@ -144,6 +144,7 @@
|
|
144 |
<label><input type="checkbox" id="logWebSocket"> WebSocket Events</label>
|
145 |
<label style="margin-left: 1em"><input type="checkbox" id="logAudio"> Audio Events</label>
|
146 |
<label style="margin-left: 1em"><input type="checkbox" id="logText"> Text Events</label>
|
|
|
147 |
<label style="margin-left: 1em"><input type="checkbox" id="logError" checked> Error Events</label>
|
148 |
</div>
|
149 |
|
@@ -211,14 +212,21 @@
|
|
211 |
}
|
212 |
|
213 |
function logMessage(category, ...args) {
|
214 |
-
const
|
215 |
-
const
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
219 |
const timestamp = new Date().toLocaleTimeString();
|
220 |
-
|
221 |
-
|
|
|
|
|
222 |
}
|
223 |
}
|
224 |
|
@@ -311,54 +319,49 @@
|
|
311 |
}
|
312 |
|
313 |
function connectWebSocket() {
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
// Use current origin and replace http(s) with ws(s)
|
318 |
-
const wsUrl = `${window.location.protocol === 'https:' ? 'wss:' : 'ws:'}//${window.location.host}/ws`;
|
319 |
-
socket = new WebSocket(wsUrl);
|
320 |
-
|
321 |
-
socket.onopen = () => {
|
322 |
-
logMessage("WebSocket", "Opened connection");
|
323 |
-
updateConnectionStatus(true);
|
324 |
-
if (!playbackCtx) {
|
325 |
-
playbackCtx = new (window.AudioContext || window.webkitAudioContext)();
|
326 |
-
setupVisualizer();
|
327 |
-
}
|
328 |
-
nextPlaybackTime = playbackCtx.currentTime;
|
329 |
-
};
|
330 |
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
stopCapture();
|
341 |
-
}
|
342 |
-
};
|
343 |
|
344 |
-
|
345 |
-
|
346 |
-
const
|
347 |
-
|
348 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
349 |
const int16View = new Int16Array(arrayBuffer);
|
350 |
const float32Buffer = new Float32Array(int16View.length);
|
351 |
for (let i = 0; i < int16View.length; i++) {
|
352 |
float32Buffer[i] = int16View[i] / 32768;
|
353 |
}
|
354 |
-
|
|
|
|
|
355 |
const audioBuffer = playbackCtx.createBuffer(1, float32Buffer.length, sampleRate);
|
356 |
audioBuffer.copyToChannel(float32Buffer, 0);
|
357 |
-
|
358 |
const source = playbackCtx.createBufferSource();
|
359 |
source.buffer = audioBuffer;
|
360 |
|
361 |
-
// Connect through analyser for visualization
|
362 |
if (analyser) {
|
363 |
source.connect(analyser);
|
364 |
analyser.connect(playbackCtx.destination);
|
@@ -368,32 +371,76 @@
|
|
368 |
} else {
|
369 |
source.connect(playbackCtx.destination);
|
370 |
}
|
371 |
-
|
372 |
-
|
373 |
-
|
|
|
|
|
|
|
|
|
374 |
scheduledSources.push(source);
|
375 |
-
|
|
|
376 |
source.onended = () => {
|
377 |
const index = scheduledSources.indexOf(source);
|
378 |
if (index > -1) {
|
379 |
scheduledSources.splice(index, 1);
|
380 |
}
|
381 |
-
// Stop visualizer if no more audio
|
382 |
if (scheduledSources.length === 0) {
|
383 |
stopVisualizer();
|
384 |
}
|
385 |
};
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
logMessage(
|
390 |
-
|
391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
392 |
}
|
393 |
-
}
|
394 |
-
|
395 |
-
|
396 |
-
|
|
|
|
|
|
|
|
|
397 |
}
|
398 |
|
399 |
async function startCapture() {
|
|
|
144 |
<label><input type="checkbox" id="logWebSocket"> WebSocket Events</label>
|
145 |
<label style="margin-left: 1em"><input type="checkbox" id="logAudio"> Audio Events</label>
|
146 |
<label style="margin-left: 1em"><input type="checkbox" id="logText"> Text Events</label>
|
147 |
+
<label style="margin-left: 1em"><input type="checkbox" id="logFunction" checked> Function Events</label>
|
148 |
<label style="margin-left: 1em"><input type="checkbox" id="logError" checked> Error Events</label>
|
149 |
</div>
|
150 |
|
|
|
212 |
}
|
213 |
|
214 |
function logMessage(category, ...args) {
|
215 |
+
const logElement = document.getElementById('log');
|
216 |
+
const shouldLog = {
|
217 |
+
'websocket': document.getElementById('logWebSocket').checked,
|
218 |
+
'audio': document.getElementById('logAudio').checked,
|
219 |
+
'text': document.getElementById('logText').checked,
|
220 |
+
'function': document.getElementById('logFunction').checked,
|
221 |
+
'error': document.getElementById('logError').checked
|
222 |
+
};
|
223 |
+
|
224 |
+
if (shouldLog[category]) {
|
225 |
const timestamp = new Date().toLocaleTimeString();
|
226 |
+
const message = `[${timestamp}] [${category}] ${args.map(arg =>
|
227 |
+
typeof arg === 'object' ? JSON.stringify(arg, null, 2) : arg
|
228 |
+
).join(' ')}`;
|
229 |
+
logElement.textContent = message + '\n' + logElement.textContent;
|
230 |
}
|
231 |
}
|
232 |
|
|
|
319 |
}
|
320 |
|
321 |
function connectWebSocket() {
|
322 |
+
try {
|
323 |
+
socket = new WebSocket(`ws://${window.location.host}/ws`);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
324 |
|
325 |
+
socket.onopen = () => {
|
326 |
+
logMessage('websocket', 'Connected to server');
|
327 |
+
updateConnectionStatus(true);
|
328 |
+
};
|
329 |
|
330 |
+
socket.onclose = () => {
|
331 |
+
logMessage('websocket', 'Disconnected from server');
|
332 |
+
updateConnectionStatus(false);
|
333 |
+
};
|
|
|
|
|
|
|
334 |
|
335 |
+
socket.onmessage = async (event) => {
|
336 |
+
const message = JSON.parse(event.data);
|
337 |
+
const messageType = message.type;
|
338 |
+
|
339 |
+
if (messageType === 'audio') {
|
340 |
+
// Handle audio data
|
341 |
+
logMessage('audio', 'Received audio chunk from server');
|
342 |
+
const arrayBuffer = base64ToArrayBuffer(message.payload);
|
343 |
+
|
344 |
+
if (!playbackCtx) {
|
345 |
+
playbackCtx = new (window.AudioContext || window.webkitAudioContext)();
|
346 |
+
setupVisualizer();
|
347 |
+
}
|
348 |
+
|
349 |
+
// Convert Int16 PCM to Float32
|
350 |
const int16View = new Int16Array(arrayBuffer);
|
351 |
const float32Buffer = new Float32Array(int16View.length);
|
352 |
for (let i = 0; i < int16View.length; i++) {
|
353 |
float32Buffer[i] = int16View[i] / 32768;
|
354 |
}
|
355 |
+
|
356 |
+
// Create audio buffer with correct sample rate
|
357 |
+
const sampleRate = 24000; // Sample rate from server
|
358 |
const audioBuffer = playbackCtx.createBuffer(1, float32Buffer.length, sampleRate);
|
359 |
audioBuffer.copyToChannel(float32Buffer, 0);
|
360 |
+
|
361 |
const source = playbackCtx.createBufferSource();
|
362 |
source.buffer = audioBuffer;
|
363 |
|
364 |
+
// Connect through analyser for visualization if available
|
365 |
if (analyser) {
|
366 |
source.connect(analyser);
|
367 |
analyser.connect(playbackCtx.destination);
|
|
|
371 |
} else {
|
372 |
source.connect(playbackCtx.destination);
|
373 |
}
|
374 |
+
|
375 |
+
// Schedule the audio to play at the right time
|
376 |
+
const startTime = Math.max(nextPlaybackTime, playbackCtx.currentTime);
|
377 |
+
source.start(startTime);
|
378 |
+
nextPlaybackTime = startTime + audioBuffer.duration;
|
379 |
+
|
380 |
+
// Keep track of scheduled sources
|
381 |
scheduledSources.push(source);
|
382 |
+
|
383 |
+
// Clean up source when it finishes playing
|
384 |
source.onended = () => {
|
385 |
const index = scheduledSources.indexOf(source);
|
386 |
if (index > -1) {
|
387 |
scheduledSources.splice(index, 1);
|
388 |
}
|
389 |
+
// Stop visualizer if no more audio playing
|
390 |
if (scheduledSources.length === 0) {
|
391 |
stopVisualizer();
|
392 |
}
|
393 |
};
|
394 |
+
}
|
395 |
+
else if (messageType === 'function_call') {
|
396 |
+
// Handle function calls from server
|
397 |
+
logMessage('function', 'Received function call:', message.payload);
|
398 |
+
const functionCalls = message.payload.functionCalls;
|
399 |
+
|
400 |
+
for (const fc of functionCalls) {
|
401 |
+
const functionName = fc.name;
|
402 |
+
const functionId = fc.id;
|
403 |
+
|
404 |
+
// Handle different functions
|
405 |
+
let result = 'ok';
|
406 |
+
|
407 |
+
if (functionName === 'turn_on_the_lights') {
|
408 |
+
logMessage('function', 'Turning on the lights');
|
409 |
+
// Simulate turning on lights
|
410 |
+
result = 'Lights turned on successfully';
|
411 |
+
}
|
412 |
+
else if (functionName === 'turn_off_the_lights') {
|
413 |
+
logMessage('function', 'Turning off the lights');
|
414 |
+
// Simulate turning off lights
|
415 |
+
result = 'Lights turned off successfully';
|
416 |
+
}
|
417 |
+
|
418 |
+
// Send response back to server
|
419 |
+
const response = {
|
420 |
+
type: 'tool_call_response',
|
421 |
+
payload: {
|
422 |
+
id: functionId,
|
423 |
+
name: functionName,
|
424 |
+
response: {
|
425 |
+
result: {
|
426 |
+
string_value: result
|
427 |
+
}
|
428 |
+
}
|
429 |
+
}
|
430 |
+
};
|
431 |
+
|
432 |
+
socket.send(JSON.stringify(response));
|
433 |
+
logMessage('function', 'Sent function response:', response);
|
434 |
+
}
|
435 |
}
|
436 |
+
};
|
437 |
+
|
438 |
+
socket.onerror = (error) => {
|
439 |
+
logMessage('error', 'WebSocket error:', error);
|
440 |
+
};
|
441 |
+
} catch (error) {
|
442 |
+
logMessage('error', 'Failed to connect:', error);
|
443 |
+
}
|
444 |
}
|
445 |
|
446 |
async function startCapture() {
|
webapp.py
CHANGED
@@ -53,6 +53,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
53 |
|
54 |
# Handle audio data from client
|
55 |
if msg_type == "audio":
|
|
|
56 |
raw_pcm = base64.b64decode(msg["payload"])
|
57 |
forward_msg = {
|
58 |
"realtime_input": {
|
@@ -97,29 +98,35 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
97 |
}
|
98 |
await audio_loop.out_queue.put(forward_msg)
|
99 |
|
|
|
|
|
|
|
|
|
100 |
else:
|
101 |
print("[from_client_to_gemini] Unknown message type:", msg_type)
|
102 |
|
103 |
except WebSocketDisconnect:
|
104 |
print("[from_client_to_gemini] Client disconnected.")
|
105 |
-
#del audio_loop
|
106 |
loop_task.cancel()
|
107 |
except Exception as e:
|
108 |
print("[from_client_to_gemini] Error:", e)
|
109 |
|
110 |
async def from_gemini_to_client():
|
111 |
-
"""Reads
|
112 |
try:
|
113 |
while True:
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
123 |
|
124 |
except WebSocketDisconnect:
|
125 |
print("[from_gemini_to_client] Client disconnected.")
|
@@ -143,5 +150,5 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
143 |
pass
|
144 |
print("[websocket_endpoint] Cleaned up AudioLoop for client")
|
145 |
|
146 |
-
if __name__ == "__main__":
|
147 |
uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
53 |
|
54 |
# Handle audio data from client
|
55 |
if msg_type == "audio":
|
56 |
+
# Decode base64 audio from client
|
57 |
raw_pcm = base64.b64decode(msg["payload"])
|
58 |
forward_msg = {
|
59 |
"realtime_input": {
|
|
|
98 |
}
|
99 |
await audio_loop.out_queue.put(forward_msg)
|
100 |
|
101 |
+
elif msg_type == "tool_call_response":
|
102 |
+
# Handle tool call response from client
|
103 |
+
await audio_loop.handle_tool_call(msg["payload"])
|
104 |
+
|
105 |
else:
|
106 |
print("[from_client_to_gemini] Unknown message type:", msg_type)
|
107 |
|
108 |
except WebSocketDisconnect:
|
109 |
print("[from_client_to_gemini] Client disconnected.")
|
|
|
110 |
loop_task.cancel()
|
111 |
except Exception as e:
|
112 |
print("[from_client_to_gemini] Error:", e)
|
113 |
|
114 |
async def from_gemini_to_client():
|
115 |
+
"""Reads messages from Gemini and sends them back to the client."""
|
116 |
try:
|
117 |
while True:
|
118 |
+
message = await audio_loop.audio_in_queue.get()
|
119 |
+
message_type = message["type"]
|
120 |
+
|
121 |
+
if message_type == "audio":
|
122 |
+
# Audio data is already base64 encoded from handler.py
|
123 |
+
await websocket.send_text(json.dumps(message))
|
124 |
+
print("[from_gemini_to_client] Sending audio chunk to client")
|
125 |
+
|
126 |
+
elif message_type == "function_call":
|
127 |
+
# Forward function call to client
|
128 |
+
await websocket.send_text(json.dumps(message))
|
129 |
+
print("[from_gemini_to_client] Forwarding function call to client")
|
130 |
|
131 |
except WebSocketDisconnect:
|
132 |
print("[from_gemini_to_client] Client disconnected.")
|
|
|
150 |
pass
|
151 |
print("[websocket_endpoint] Cleaned up AudioLoop for client")
|
152 |
|
153 |
+
if __name__ == "__main__":
|
154 |
uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)
|