Nirav Madhani committed
Commit a1096d7 · Parent(s): 9b731f8

Front end function call

Files changed (4):
  1. .gitignore  +2 -0
  2. handler.py  +25 -12
  3. index.html  +104 -57
  4. webapp.py   +19 -12
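In short: instead of handler.py auto-acknowledging every Gemini tool call with a stubbed "ok", tool calls are now forwarded to the browser as a "function_call" message, handled there, and returned as a "tool_call_response", which handler.py relays to Gemini. All traffic on the /ws socket uses a {type, payload} JSON envelope. A minimal sketch of the three envelopes follows; the ids and result strings are illustrative, not part of the commit:

import base64
import json

# Server -> client: base64-encoded 24 kHz mono int16 PCM
audio_msg = {"type": "audio", "payload": base64.b64encode(b"\x00\x00").decode()}

# Server -> client: Gemini tool call forwarded verbatim
function_call_msg = {
    "type": "function_call",
    "payload": {"functionCalls": [{"id": "call-1", "name": "turn_on_the_lights"}]},
}

# Client -> server: the browser's answer, relayed to Gemini by handle_tool_call()
tool_call_response_msg = {
    "type": "tool_call_response",
    "payload": {
        "id": "call-1",
        "name": "turn_on_the_lights",
        "response": {"result": {"string_value": "Lights turned on successfully"}},
    },
}

for m in (audio_msg, function_call_msg, tool_call_response_msg):
    print(json.dumps(m))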
.gitignore ADDED
@@ -0,0 +1,2 @@
+.env
+__pycache__/*
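The newly ignored .env pairs with the load_dotenv() call added to handler.py below. A sketch of its expected contents, with a placeholder value; GOOGLE_API_KEY is the only variable the code reads:

# .env (read by python-dotenv at startup; never committed)
GOOGLE_API_KEY=your-api-key-here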
handler.py CHANGED
@@ -6,11 +6,18 @@ import os
 import traceback
 from websockets.asyncio.client import connect
 
+from dotenv import load_dotenv
+
+# Load environment variables from a .env file
+load_dotenv()
+
 host = "generativelanguage.googleapis.com"
 model = "gemini-2.0-flash-exp"
 api_key = os.environ["GOOGLE_API_KEY"]
 uri = f"wss://{host}/ws/google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent?key={api_key}"
 
+
+
 class AudioLoop:
     def __init__(self):
         self.ws = None
@@ -54,33 +61,39 @@ class AudioLoop:
                 response["serverContent"]["modelTurn"]["parts"][0]["inlineData"]["data"]
             )
             pcm_data = base64.b64decode(b64data)
-            await self.audio_in_queue.put(pcm_data)
+            # Send audio with type "audio"
+            await self.audio_in_queue.put({
+                "type": "audio",
+                "payload": base64.b64encode(pcm_data).decode()
+            })
         except KeyError:
             # No audio in this message
             pass
 
+        # Forward function calls to client
         tool_call = response.pop('toolCall', None)
         if tool_call is not None:
-            await self.handle_tool_call(tool_call)
+            await self.audio_in_queue.put({
+                "type": "function_call",
+                "payload": tool_call
+            })
 
         # If "turnComplete" is present
         if "serverContent" in response and response["serverContent"].get("turnComplete"):
             print("[AudioLoop] Gemini turn complete")
 
-    async def handle_tool_call(self, tool_call):
-        print(" ", tool_call)
-        for fc in tool_call['functionCalls']:
-            msg = {
-                'tool_response': {
-                    'function_responses': [{
-                        'id': fc['id'],
-                        'name': fc['name'],
-                        'response': {'result': {'string_value': 'ok'}}
-                    }]
-                }
-            }
-            print('>>> ', msg)
-            await self.ws.send(json.dumps(msg))
+    async def handle_tool_call(self, tool_call_response):
+        """Handle tool call response from client"""
+        msg = {
+            'tool_response': {
+                'function_responses': [{
+                    'id': tool_call_response['id'],
+                    'name': tool_call_response['name'],
+                    'response': tool_call_response['response']
+                }]
+            }
+        }
+        await self.ws.send(json.dumps(msg))
 
     async def run(self):
         """Main entry point: connects to Gemini, starts send/receive tasks."""
index.html CHANGED
@@ -144,6 +144,7 @@
     <label><input type="checkbox" id="logWebSocket"> WebSocket Events</label>
     <label style="margin-left: 1em"><input type="checkbox" id="logAudio"> Audio Events</label>
     <label style="margin-left: 1em"><input type="checkbox" id="logText"> Text Events</label>
+    <label style="margin-left: 1em"><input type="checkbox" id="logFunction" checked> Function Events</label>
     <label style="margin-left: 1em"><input type="checkbox" id="logError" checked> Error Events</label>
   </div>
 
@@ -211,14 +212,21 @@
     }
 
     function logMessage(category, ...args) {
-      const pre = document.getElementById("log");
-      const logCategory = document.getElementById(`log${category.charAt(0).toUpperCase() + category.slice(1)}`);
-      const shouldLog = logCategory ? logCategory.checked : false;
-
-      if (shouldLog) {
+      const logElement = document.getElementById('log');
+      const shouldLog = {
+        'websocket': document.getElementById('logWebSocket').checked,
+        'audio': document.getElementById('logAudio').checked,
+        'text': document.getElementById('logText').checked,
+        'function': document.getElementById('logFunction').checked,
+        'error': document.getElementById('logError').checked
+      };
+
+      if (shouldLog[category]) {
         const timestamp = new Date().toLocaleTimeString();
-        pre.textContent += `[${timestamp}] [${category}] ` + args.join(" ") + "\n";
-        console.log(`[${category}]`, ...args);
+        const message = `[${timestamp}] [${category}] ${args.map(arg =>
+          typeof arg === 'object' ? JSON.stringify(arg, null, 2) : arg
+        ).join(' ')}`;
+        logElement.textContent = message + '\n' + logElement.textContent;
       }
     }
 
@@ -311,54 +319,49 @@
     }
 
     function connectWebSocket() {
-      logMessage("WebSocket", "Connecting...");
-      updateConnectionStatus(false);
-
-      // Use current origin and replace http(s) with ws(s)
-      const wsUrl = `${window.location.protocol === 'https:' ? 'wss:' : 'ws:'}//${window.location.host}/ws`;
-      socket = new WebSocket(wsUrl);
-
-      socket.onopen = () => {
-        logMessage("WebSocket", "Opened connection");
-        updateConnectionStatus(true);
-        if (!playbackCtx) {
-          playbackCtx = new (window.AudioContext || window.webkitAudioContext)();
-          setupVisualizer();
-        }
-        nextPlaybackTime = playbackCtx.currentTime;
-      };
-
-      socket.onerror = (err) => {
-        logMessage("Error", "WebSocket error:", err);
-        updateConnectionStatus(false);
-      };
-
-      socket.onclose = () => {
-        logMessage("WebSocket", "Connection closed");
-        updateConnectionStatus(false);
-        if (isCapturing) {
-          stopCapture();
-        }
-      };
-
-      socket.onmessage = (event) => {
-        try {
-          const data = JSON.parse(event.data);
-          if (data.type === "audio" && data.payload) {
-            const arrayBuffer = base64ToArrayBuffer(data.payload);
+      try {
+        socket = new WebSocket(`ws://${window.location.host}/ws`);
+
+        socket.onopen = () => {
+          logMessage('websocket', 'Connected to server');
+          updateConnectionStatus(true);
+        };
+
+        socket.onclose = () => {
+          logMessage('websocket', 'Disconnected from server');
+          updateConnectionStatus(false);
+        };
+
+        socket.onmessage = async (event) => {
+          const message = JSON.parse(event.data);
+          const messageType = message.type;
+
+          if (messageType === 'audio') {
+            // Handle audio data
+            logMessage('audio', 'Received audio chunk from server');
+            const arrayBuffer = base64ToArrayBuffer(message.payload);
+
+            if (!playbackCtx) {
+              playbackCtx = new (window.AudioContext || window.webkitAudioContext)();
+              setupVisualizer();
+            }
+
+            // Convert Int16 PCM to Float32
             const int16View = new Int16Array(arrayBuffer);
             const float32Buffer = new Float32Array(int16View.length);
             for (let i = 0; i < int16View.length; i++) {
               float32Buffer[i] = int16View[i] / 32768;
             }
-            const sampleRate = 24000; // RECEIVED_SAMPLE_RATE from app.py
+
+            // Create audio buffer with correct sample rate
+            const sampleRate = 24000; // Sample rate from server
             const audioBuffer = playbackCtx.createBuffer(1, float32Buffer.length, sampleRate);
             audioBuffer.copyToChannel(float32Buffer, 0);
-            let scheduledTime = playbackCtx.currentTime > nextPlaybackTime ? playbackCtx.currentTime : nextPlaybackTime;
+
             const source = playbackCtx.createBufferSource();
             source.buffer = audioBuffer;
 
-            // Connect through analyser for visualization
+            // Connect through analyser for visualization if available
             if (analyser) {
               source.connect(analyser);
               analyser.connect(playbackCtx.destination);
@@ -368,32 +371,76 @@
             } else {
               source.connect(playbackCtx.destination);
             }
-
-            source.start(scheduledTime);
-            // Add source to tracked sources
+
+            // Schedule the audio to play at the right time
+            const startTime = Math.max(nextPlaybackTime, playbackCtx.currentTime);
+            source.start(startTime);
+            nextPlaybackTime = startTime + audioBuffer.duration;
+
+            // Keep track of scheduled sources
             scheduledSources.push(source);
-            // Remove source from tracking once it finishes
+
+            // Clean up source when it finishes playing
            source.onended = () => {
              const index = scheduledSources.indexOf(source);
              if (index > -1) {
                scheduledSources.splice(index, 1);
              }
-              // Stop visualizer if no more audio
+              // Stop visualizer if no more audio playing
              if (scheduledSources.length === 0) {
                stopVisualizer();
              }
            };
-            nextPlaybackTime = scheduledTime + audioBuffer.duration;
-            logMessage("Audio", "Scheduled playback. Start time:", scheduledTime, "Duration:", audioBuffer.duration);
-          } else if (data.type === "text" && data.content) {
-            logMessage("Text", "Received:", data.content);
-          } else {
-            logMessage("WebSocket", "Received message:", event.data);
+          }
+          else if (messageType === 'function_call') {
+            // Handle function calls from server
+            logMessage('function', 'Received function call:', message.payload);
+            const functionCalls = message.payload.functionCalls;
+
+            for (const fc of functionCalls) {
+              const functionName = fc.name;
+              const functionId = fc.id;
+
+              // Handle different functions
+              let result = 'ok';
+
+              if (functionName === 'turn_on_the_lights') {
+                logMessage('function', 'Turning on the lights');
+                // Simulate turning on lights
+                result = 'Lights turned on successfully';
+              }
+              else if (functionName === 'turn_off_the_lights') {
+                logMessage('function', 'Turning off the lights');
+                // Simulate turning off lights
+                result = 'Lights turned off successfully';
+              }
+
+              // Send response back to server
+              const response = {
+                type: 'tool_call_response',
+                payload: {
+                  id: functionId,
+                  name: functionName,
+                  response: {
+                    result: {
+                      string_value: result
+                    }
+                  }
+                }
+              };
+
+              socket.send(JSON.stringify(response));
+              logMessage('function', 'Sent function response:', response);
+            }
           }
-        } catch (err) {
-          logMessage("Error", "Failed to process message:", err);
-        }
-      };
+        };
+
+        socket.onerror = (error) => {
+          logMessage('error', 'WebSocket error:', error);
+        };
+      } catch (error) {
+        logMessage('error', 'Failed to connect:', error);
+      }
     }
 
     async function startCapture() {
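The page answers tool calls with a simulated if/else over two light-control functions. For exercising the same protocol without a browser, here is the equivalent dispatch in Python; the function names and result strings come from the diff, while the call id is illustrative:

import json

# Stubbed handlers mirroring the page's if/else chain
HANDLERS = {
    "turn_on_the_lights": lambda: "Lights turned on successfully",
    "turn_off_the_lights": lambda: "Lights turned off successfully",
}

def answer_function_calls(payload):
    """Build one tool_call_response envelope per function call, as the page does."""
    out = []
    for fc in payload["functionCalls"]:
        result = HANDLERS.get(fc["name"], lambda: "ok")()  # 'ok' is the page's default
        out.append(json.dumps({
            "type": "tool_call_response",
            "payload": {
                "id": fc["id"],
                "name": fc["name"],
                "response": {"result": {"string_value": result}},
            },
        }))
    return out

print(answer_function_calls(
    {"functionCalls": [{"id": "call-1", "name": "turn_on_the_lights"}]}
)[0])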
webapp.py CHANGED
@@ -53,6 +53,7 @@ async def websocket_endpoint(websocket: WebSocket):
 
                 # Handle audio data from client
                 if msg_type == "audio":
+                    # Decode base64 audio from client
                     raw_pcm = base64.b64decode(msg["payload"])
                     forward_msg = {
                         "realtime_input": {
@@ -97,29 +98,35 @@ async def websocket_endpoint(websocket: WebSocket):
                         }
                     }
                     await audio_loop.out_queue.put(forward_msg)
 
+                elif msg_type == "tool_call_response":
+                    # Handle tool call response from client
+                    await audio_loop.handle_tool_call(msg["payload"])
+
                 else:
                     print("[from_client_to_gemini] Unknown message type:", msg_type)
 
         except WebSocketDisconnect:
             print("[from_client_to_gemini] Client disconnected.")
-            #del audio_loop
             loop_task.cancel()
         except Exception as e:
             print("[from_client_to_gemini] Error:", e)
 
     async def from_gemini_to_client():
-        """Reads PCM audio from Gemini and sends it back to the client."""
+        """Reads messages from Gemini and sends them back to the client."""
         try:
             while True:
-                pcm_data = await audio_loop.audio_in_queue.get()
-                b64_pcm = base64.b64encode(pcm_data).decode()
-
-                out_msg = {
-                    "type": "audio",
-                    "payload": b64_pcm
-                }
-                print("[from_gemini_to_client] Sending audio chunk to client. Size:", len(pcm_data))
-                await websocket.send_text(json.dumps(out_msg))
+                message = await audio_loop.audio_in_queue.get()
+                message_type = message["type"]
+
+                if message_type == "audio":
+                    # Audio data is already base64 encoded from handler.py
+                    await websocket.send_text(json.dumps(message))
+                    print("[from_gemini_to_client] Sending audio chunk to client")
+
+                elif message_type == "function_call":
+                    # Forward function call to client
+                    await websocket.send_text(json.dumps(message))
+                    print("[from_gemini_to_client] Forwarding function call to client")
 
         except WebSocketDisconnect:
             print("[from_gemini_to_client] Client disconnected.")
@@ -143,5 +150,5 @@ async def websocket_endpoint(websocket: WebSocket):
         pass
     print("[websocket_endpoint] Cleaned up AudioLoop for client")
 
-if __name__ == "__main__":
+if __name__ == "__main__":
     uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)
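A minimal headless client for smoke-testing the whole loop, as a sketch rather than part of the commit: it assumes webapp.py is running locally on port 7860 as configured above, and it answers every forwarded tool call with a stub result, much like the page does. It uses the same websockets client API that handler.py already imports.

import asyncio
import json

from websockets.asyncio.client import connect  # same client API handler.py uses

async def main():
    # Assumes the FastAPI app is serving /ws on localhost:7860 (per the diff)
    async with connect("ws://localhost:7860/ws") as ws:
        async for raw in ws:
            msg = json.loads(raw)
            if msg["type"] == "function_call":
                # Answer every call with a stub result, like the page does
                for fc in msg["payload"]["functionCalls"]:
                    await ws.send(json.dumps({
                        "type": "tool_call_response",
                        "payload": {
                            "id": fc["id"],
                            "name": fc["name"],
                            "response": {"result": {"string_value": "ok"}},
                        },
                    }))
            elif msg["type"] == "audio":
                print("audio chunk:", len(msg["payload"]), "base64 chars")

asyncio.run(main())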