File size: 4,953 Bytes
30a4774
 
 
 
 
 
 
 
2ec08a6
30a4774
 
 
 
 
c6d11c0
2ec08a6
30a4774
2ec08a6
 
30a4774
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d219fb
 
30a4774
9d219fb
30a4774
9d219fb
b88286b
cfb8318
 
b88286b
 
 
 
2ec08a6
 
 
 
b88286b
2ec08a6
 
 
30a4774
 
cfb8318
 
b88286b
 
 
 
 
 
9d219fb
 
b88286b
30a4774
cfb8318
2ec08a6
 
 
 
30a4774
 
 
 
 
 
b88286b
30a4774
 
cfb8318
 
 
 
30a4774
 
 
 
2ec08a6
30a4774
 
 
 
 
 
 
9d219fb
cfb8318
30a4774
9d219fb
 
 
 
 
 
 
30a4774
 
 
b88286b
 
30a4774
 
 
 
 
9d219fb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import asyncio
import base64
import os
import time
from io import BytesIO

import gradio as gr
import numpy as np
import websockets
from dotenv import load_dotenv
from fastrtc import (
    AsyncAudioVideoStreamHandler,
    Stream,
    WebRTC,
    get_cloudflare_turn_credentials_async,
    wait_for_item,
)
from google import genai
from gradio.utils import get_space
from PIL import Image

# Load variables from a local .env file (if one is found) into the process
# environment before anything below is constructed.
load_dotenv()


def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw audio samples in a base64 PCM message for the server.

    The array's underlying bytes are base64-encoded and returned as text
    under ``"data"``, tagged with the ``audio/pcm`` MIME type.
    """
    raw_bytes = data.tobytes()
    b64_text = str(base64.b64encode(raw_bytes), "UTF-8")
    message = {"mime_type": "audio/pcm"}
    message["data"] = b64_text
    return message


def encode_image(data: np.ndarray) -> dict:
    """Compress an image array to JPEG and wrap it as a base64 message.

    The array is converted to a PIL image, saved as JPEG into an in-memory
    buffer, and the resulting bytes are base64-encoded as UTF-8 text under
    ``"data"``, tagged with the ``image/jpeg`` MIME type.
    """
    with BytesIO() as buffer:
        Image.fromarray(data).save(buffer, "JPEG")
        jpeg_bytes = buffer.getvalue()
    encoded = base64.b64encode(jpeg_bytes).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": encoded}


class GeminiHandler(AsyncAudioVideoStreamHandler):
    """Relay a client's audio/video stream to a Gemini live session.

    Incoming audio and (at most one per second) video frames are forwarded
    to the live session; audio returned by the model is queued and handed
    back through ``emit``. Video frames are echoed back to the client
    unchanged through ``video_emit``.
    """

    def __init__(
        self,
    ) -> None:
        super().__init__(
            "mono",
            output_sample_rate=24000,
            input_sample_rate=16000,
        )
        # Audio chunks received from the model, drained by emit().
        self.audio_queue = asyncio.Queue()
        # Frames received from the client, drained by video_emit().
        self.video_queue = asyncio.Queue()
        # Live session object; stays None until start_up() has connected.
        self.session = None
        # time.time() of the last frame forwarded to the model (0 = never).
        self.last_frame_time = 0
        # Set by shutdown() so the receive loop in start_up() stops.
        self.quit = asyncio.Event()

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with its own queues and session state."""
        return GeminiHandler()

    async def start_up(self):
        """Open the live session and queue the model's audio replies.

        Runs until ``quit`` is set or the connection is closed cleanly.
        """
        await self.wait_for_args()
        # Index 3 is used as the API key. NOTE(review): presumably index 0 is
        # the stream component's own value and 1..3 are the additional
        # inputs in declaration order - confirm against fastrtc.
        api_key = self.latest_args[3]
        client = genai.Client(
            api_key=api_key, http_options={"api_version": "v1alpha"}
        )
        config = {
            "response_modalities": ["AUDIO"],
            "system_instruction": (
                "You are an art history teacher that will describe the artwork "
                "passed in as an image to the user. Describe the history and "
                "significance of the artwork."
            ),
        }
        async with client.aio.live.connect(
            model="gemini-2.0-flash-exp",
            config=config,  # type: ignore
        ) as session:
            self.session = session
            while not self.quit.is_set():
                turn = self.session.receive()
                try:
                    async for response in turn:
                        if data := response.data:
                            audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1)
                            # Enqueue only when this response carried audio.
                            # Doing it unconditionally would replay the
                            # previous chunk (or hit an unbound local on the
                            # first response) for responses without data.
                            self.audio_queue.put_nowait(audio)
                except websockets.exceptions.ConnectionClosedOK:
                    print("connection closed")
                    break

    async def video_receive(self, frame: np.ndarray):
        """Queue ``frame`` for echo and forward it to the model once a second.

        When a forward happens and ``latest_args[2]`` (the uploaded image)
        is not None, that image is sent right after the frame.
        """
        self.video_queue.put_nowait(frame)

        if not self.session:
            return

        # Read the clock once so the comparison and the stored timestamp agree.
        now = time.time()
        # send image every 1 second
        if now - self.last_frame_time > 1:
            self.last_frame_time = now
            await self.session.send(input=encode_image(frame))
            if self.latest_args[2] is not None:
                await self.session.send(input=encode_image(self.latest_args[2]))

    async def video_emit(self):
        """Return the next queued frame, or a black 100x100 placeholder.

        NOTE(review): 0.01 is presumably a timeout in seconds for
        ``wait_for_item`` - confirm against fastrtc.
        """
        frame = await wait_for_item(self.video_queue, 0.01)
        if frame is not None:
            return frame
        else:
            return np.zeros((100, 100, 3), dtype=np.uint8)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one incoming audio frame to the model, if connected.

        Args:
            frame: ``(first_item, samples)`` pair; the first item is ignored
                and the samples are squeezed before encoding.
        """
        _, array = frame
        array = array.squeeze()
        audio_message = encode_audio(array)
        if self.session:
            await self.session.send(input=audio_message)

    async def emit(self):
        """Return ``(output_sample_rate, chunk)`` for the next queued audio.

        Returns whatever ``wait_for_item`` yields (None) when no chunk is
        available.
        """
        array = await wait_for_item(self.audio_queue, 0.01)
        if array is not None:
            return (self.output_sample_rate, array)
        return array

    async def shutdown(self) -> None:
        """Stop the receive loop and close the live session, if one exists."""
        if self.session:
            self.quit.set()
            await self.session.close()
            self.quit.clear()


# Application wiring: one GeminiHandler served over an audio+video stream.
stream = Stream(
    handler=GeminiHandler(),
    modality="audio-video",
    mode="send-receive",
    # TURN credentials and the 180 s limit apply only when get_space() is truthy.
    rtc_configuration=get_cloudflare_turn_credentials_async if get_space() else None,
    time_limit=180 if get_space() else None,
    # Order matters: GeminiHandler reads these positionally from latest_args
    # ([2] is the image, [3] is the API key).
    additional_inputs=[
        gr.Markdown(
            "## 🎨 Art History Teacher\n\n"
            # Trailing space keeps the two sentences from running together.
            "Provide an image of the artwork and Gemini will describe it to you. "
            "To get a Gemini API key, please visit the [Gemini API Key](https://console.cloud.google.com/apis/api/generativelanguage.googleapis.com/credentials) page."
        ),
        gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"]),
        # Mask the key in the UI so it is not shown in clear text.
        gr.Textbox(label="Gemini API Key", type="password"),
    ],
    ui_args={
        "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
        "pulse_color": "rgb(255, 255, 255)",
        "icon_button_color": "rgb(255, 255, 255)",
        "title": "Gemini Audio Video Chat",
    },
)

if __name__ == "__main__":
    # Launch the UI only when run as a script, not on import.
    stream.ui.launch()