Spaces:
Build error
Build error
File size: 4,953 Bytes
30a4774 2ec08a6 30a4774 c6d11c0 2ec08a6 30a4774 2ec08a6 30a4774 9d219fb 30a4774 9d219fb 30a4774 9d219fb b88286b cfb8318 b88286b 2ec08a6 b88286b 2ec08a6 30a4774 cfb8318 b88286b 9d219fb b88286b 30a4774 cfb8318 2ec08a6 30a4774 b88286b 30a4774 cfb8318 30a4774 2ec08a6 30a4774 9d219fb cfb8318 30a4774 9d219fb 30a4774 b88286b 30a4774 9d219fb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import asyncio
import base64
import os
import time
from io import BytesIO
import gradio as gr
import numpy as np
import websockets
from dotenv import load_dotenv
from fastrtc import (
AsyncAudioVideoStreamHandler,
Stream,
WebRTC,
get_cloudflare_turn_credentials_async,
wait_for_item,
)
from google import genai
from gradio.utils import get_space
from PIL import Image
# Populate os.environ from a local .env file, if one exists (python-dotenv).
load_dotenv()
def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw audio samples in a base64 ``audio/pcm`` message for the server.

    Args:
        data: Sample array; its raw buffer (``tobytes()``) is what gets encoded.

    Returns:
        ``{"mime_type": "audio/pcm", "data": <base64 text>}``.
    """
    raw_pcm = data.tobytes()
    encoded = base64.b64encode(raw_pcm).decode("UTF-8")
    return {"mime_type": "audio/pcm", "data": encoded}
def encode_image(data: np.ndarray) -> dict:
    """Encode an image array as a base64 JPEG message for the server.

    Args:
        data: Pixel array accepted by ``PIL.Image.fromarray``.

    Returns:
        ``{"mime_type": "image/jpeg", "data": <base64 text of the JPEG bytes>}``.
    """
    pil_image = Image.fromarray(data)
    # JPEG has no alpha channel and Pillow refuses to save RGBA/LA images as
    # JPEG, so 4- and 2-channel arrays are flattened before encoding.
    if pil_image.mode == "RGBA":
        pil_image = pil_image.convert("RGB")
    elif pil_image.mode == "LA":
        pil_image = pil_image.convert("L")
    with BytesIO() as output_bytes:
        pil_image.save(output_bytes, "JPEG")
        bytes_data = output_bytes.getvalue()
    base64_str = base64.b64encode(bytes_data).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": base64_str}
class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relay a FastRTC audio/video stream to a Gemini Live session.

    Incoming audio is forwarded as PCM (16 kHz input). Incoming video frames are
    echoed back locally and sent to Gemini at most once per ``FRAME_INTERVAL``
    seconds, together with the uploaded image. Gemini's audio replies are queued
    and emitted at 24 kHz.
    """

    # Minimum number of seconds between frames forwarded to Gemini.
    FRAME_INTERVAL = 1.0
    # Positions in ``self.latest_args``. They track the order of the Stream's
    # additional_inputs (Markdown, Image, Textbox), shifted by one leading entry.
    # NOTE(review): confirm that leading entry against fastrtc's latest_args.
    IMAGE_ARG_INDEX = 2
    API_KEY_ARG_INDEX = 3

    def __init__(self) -> None:
        super().__init__(
            "mono",
            output_sample_rate=24000,
            input_sample_rate=16000,
        )
        # Gemini audio replies waiting to be returned by emit().
        self.audio_queue: asyncio.Queue = asyncio.Queue()
        # Incoming video frames waiting to be returned by video_emit().
        self.video_queue: asyncio.Queue = asyncio.Queue()
        # Active Gemini Live session, or None when not connected.
        self.session = None
        # time.monotonic() timestamp of the last frame forwarded to Gemini.
        self.last_frame_time = 0.0
        # Set by shutdown() to stop the receive loop in start_up().
        self.quit = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with no session and empty queues."""
        return GeminiHandler()

    async def start_up(self) -> None:
        """Open a Gemini Live session and pump its audio replies into ``audio_queue``.

        Runs until shutdown() sets ``quit`` or the server closes the connection
        normally.
        """
        # Reset here rather than in shutdown(): clearing right after close()
        # let the loop below re-check a cleared flag and poll a closed session.
        self.quit.clear()
        await self.wait_for_args()
        # Prefer the key typed into the UI; fall back to the environment
        # (which load_dotenv() may populate from .env) when the box is empty.
        api_key = self.latest_args[self.API_KEY_ARG_INDEX] or os.getenv("GEMINI_API_KEY")
        client = genai.Client(
            api_key=api_key, http_options={"api_version": "v1alpha"}
        )
        config = {"response_modalities": ["AUDIO"], "system_instruction": "You are an art history teacher that will describe the artwork passed in as an image to the user. Describe the history and significance of the artwork."}
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp",
            config=config,  # type: ignore
        ) as session:
            self.session = session
            try:
                while not self.quit.is_set():
                    turn = session.receive()
                    try:
                        async for response in turn:
                            if data := response.data:
                                # Raw 16-bit PCM bytes -> one row of samples.
                                audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
                                self.audio_queue.put_nowait(audio)
                    except websockets.exceptions.ConnectionClosedOK:
                        print("connection closed")
                        break
            finally:
                # The context manager closes the session on exit; drop our
                # reference so receive()/video_receive() stop sending to it.
                self.session = None

    async def video_receive(self, frame: np.ndarray) -> None:
        """Queue ``frame`` for local echo and, throttled, forward it to Gemini.

        At most once every ``FRAME_INTERVAL`` seconds, the frame is sent to the
        session, followed by the uploaded image if one is set.
        """
        self.video_queue.put_nowait(frame)
        # Snapshot the session: start_up() may reset self.session to None
        # while a send below is being awaited.
        session = self.session
        if not session:
            return
        # Monotonic clock: immune to wall-clock adjustments.
        now = time.monotonic()
        if now - self.last_frame_time <= self.FRAME_INTERVAL:
            return
        self.last_frame_time = now
        await session.send(input=encode_image(frame))
        image = self.latest_args[self.IMAGE_ARG_INDEX]
        if image is not None:
            await session.send(input=encode_image(image))

    async def video_emit(self) -> np.ndarray:
        """Return the next queued frame, or a 100x100 black placeholder if none is ready."""
        frame = await wait_for_item(self.video_queue, 0.01)
        if frame is None:
            return np.zeros((100, 100, 3), dtype=np.uint8)
        return frame

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward an incoming ``(sample_rate, samples)`` audio chunk to Gemini, if connected."""
        session = self.session
        if not session:
            return
        _, array = frame
        await session.send(input=encode_audio(array.squeeze()))

    async def emit(self):
        """Return ``(output_sample_rate, samples)`` for the next queued reply, or None."""
        array = await wait_for_item(self.audio_queue, 0.01)
        if array is None:
            return None
        return (self.output_sample_rate, array)

    async def shutdown(self) -> None:
        """Stop the receive loop and close the Gemini session, if one is open."""
        # Set unconditionally, so a shutdown that arrives before the session
        # is connected still stops start_up() once it connects.
        self.quit.set()
        if self.session:
            await self.session.close()
# The FastRTC stream: one GeminiHandler per connection, sending and receiving
# both audio and video. The order of additional_inputs matters: GeminiHandler
# reads the image and API key from latest_args by position.
stream = Stream(
    handler=GeminiHandler(),
    modality="audio-video",
    mode="send-receive",
    # On a Space (get_space() truthy), fetch Cloudflare TURN credentials;
    # locally, pass no RTC configuration.
    rtc_configuration=get_cloudflare_turn_credentials_async if get_space() else None,
    # Session time limit on Spaces (presumably seconds); unlimited locally.
    time_limit=180 if get_space() else None,
    additional_inputs=[
        gr.Markdown(
            "## 🎨 Art History Teacher\n\n"
            # Trailing space: the adjacent literals are concatenated, and
            # without it the two sentences run together ("you.To").
            "Provide an image of the artwork and Gemini will describe it to you. "
            "To get a Gemini API key, please visit the [Gemini API Key](https://console.cloud.google.com/apis/api/generativelanguage.googleapis.com/credentials) page."
        ),
        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"]),
        # Masked input, so the secret key is not displayed in plain text.
        gr.Textbox(label="Gemini API Key", type="password"),
    ],
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(255, 255, 255)",
        "icon_button_color": "rgb(255, 255, 255)",
        "title": "Gemini Audio Video Chat",
    },
)
if __name__ == "__main__":
    # Running this file directly serves the stream's built-in Gradio UI.
    app = stream.ui
    app.launch()
|