Ali Sartaz Khan committed on
Commit 3c8c320 · 1 Parent(s): c2b5b47

Add application file
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
app.py ADDED
@@ -0,0 +1,8 @@
+ # app.py
+ import sys
+
+ # Make the local talk-arena checkout importable before importing from it.
+ sys.path.append("talk-arena")
+ from talk_arena.audio_collection import demo
+
+ demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)
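A minimal local smoke test for this entrypoint could look like the sketch below; the share=False launch and the smaller queue limit are assumptions for private testing, not part of the commit:

# smoke_test.py — hypothetical local check; assumes a talk-arena checkout
# sits next to this file, exactly as app.py expects.
import sys

sys.path.append("talk-arena")
from talk_arena.audio_collection import demo

# Launch on localhost only; app.py uses share=True, which exposes a public link.
demo.queue(default_concurrency_limit=4, api_open=False).launch(share=False)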
audio_out_votes.json ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ transformers==4.45.2
+ transformers-stream-generator==0.0.5
+ accelerate>=0.26.0
+ peft
+ gradio==5.8.0
+ tinydb==4.8.0
+ xxhash==3.4.1
+ google-ai-generativelanguage==0.6.10
+ google-generativeai
+ datasets==2.18.0
+ librosa==0.10.1
+ soundfile==0.12.1
+ openai==1.52.0
+ python-dotenv==1.0.1
+ httpx==0.27.2
talk_arena/.env ADDED
@@ -0,0 +1,2 @@
+ OPENAI_API_KEY="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"
+ GEMINI_API_KEY="AIzaSyAM6XTT9S9nzE09jj5o-UNDZ4f8INPyWBM"
talk_arena/__init__.py ADDED
File without changes
talk_arena/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (175 Bytes)

talk_arena/__pycache__/db_utils.cpython-312.pyc ADDED
Binary file (2.62 kB)

talk_arena/__pycache__/streaming_helpers.cpython-312.pyc ADDED
Binary file (18.8 kB)
talk_arena/audio_collection.py ADDED
@@ -0,0 +1,448 @@
+ import argparse
+ import asyncio
+ import os
+ import random
+ import textwrap
+ import time
+
+ import gradio as gr
+ import numpy as np
+ import soundfile as sf
+ import xxhash
+ from datasets import Audio
+ from dotenv import load_dotenv
+ from openai import OpenAI
+
+ import talk_arena.streaming_helpers as sh
+ from talk_arena.db_utils import TinyThreadSafeDB
+
+ load_dotenv()
+ resampler = Audio(sampling_rate=16_000)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Talk Arena Demo")
+     parser.add_argument("--free_only", action="store_true", help="Only use free models")
+     return parser.parse_args()
+
+
+ args = parse_args()
+
+ if gr.NO_RELOAD:  # Prevents re-init during hot reloading
+     # Transcription disabled for the public interface
+     # asr_pipe = pipeline(
+     #     task="automatic-speech-recognition",
+     #     model="openai/whisper-large-v3-turbo",
+     #     chunk_length_s=30,
+     #     device="cuda:1",
+     # )
+
+     anonymous = True
+
+     gpt4o_audio, gpt4o_model = sh.gpt4o_streaming("models/gpt4o")
+     gemini2_audio, gemini2_model = sh.gemini_streaming("models/gemini-2.0-flash-exp")
+     competitor_info = [
+         (sh.gradio_gen_factory(gpt4o_audio, "GPT4o", anonymous), "gpt4o", "GPT-4o"),
+         (sh.gradio_gen_factory(gemini2_audio, "Gemini 2 Flash", anonymous), "gemini_2f", "Gemini 2 Flash"),
+     ]
+
+     resp_generators = [generator for generator, _, _ in competitor_info]
+     model_shorthand = [shorthand for _, shorthand, _ in competitor_info]
+     model_name = [full_name for _, _, full_name in competitor_info]
+     all_models = list(range(len(model_shorthand)))
+
+
+ async def pairwise_response_async(audio_input, state, model_order):
+     if audio_input is None:
+         raise StopAsyncIteration(
+             "",
+             "",
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             state,
+             audio_input,
+             None,
+             None,
+             None,
+         )
+     spinner_id = 0
+     spinners = ["◐ ", "◓ ", "◑", "◒"]
+     spinner = spinners[0]
+     gen_pair = [resp_generators[model_order[0]], resp_generators[model_order[1]]]
+     latencies = [{}, {}]  # Store timing info for each model
+     resps = [gr.Textbox(value="", info="", visible=False), gr.Textbox(value="", info="", visible=False)]
+     tts_resps = [gr.Audio(), gr.Audio()]
+     error_in_model = False
+     for order, generator in enumerate(gen_pair):
+         start_time = time.time()
+         first_token = True
+         total_length = 0
+         try:
+             async for local_resp in generator(audio_input, order):
+                 total_length += 1
+                 if first_token:
+                     latencies[order]["time_to_first_token"] = time.time() - start_time
+                     first_token = False
+                 resps[order] = local_resp
+                 spinner = spinners[spinner_id]
+                 spinner_id = (spinner_id + 1) % 4
+                 yield (
+                     gr.Button(
+                         value=spinner + " Generating Responses " + spinner,
+                         interactive=False,
+                         variant="primary",
+                     ),
+                     resps[0],
+                     resps[1],
+                     tts_resps[0],
+                     tts_resps[1],
+                     gr.Button(visible=False),
+                     gr.Button(visible=False),
+                     gr.Button(visible=False),
+                     state,
+                     audio_input,
+                     None,
+                     None,
+                     latencies,
+                 )
+             latencies[order]["total_time"] = time.time() - start_time
+             latencies[order]["response_length"] = total_length
+         except Exception:
+             error_in_model = True
+             resps[order] = gr.Textbox(
+                 info=f"<strong>Error thrown by Model {order+1} API</strong>",
+                 value="" if first_token else resps[order]._constructor_args[0]["value"],
+                 visible=True,
+                 label=f"Model {order+1}",
+             )
+             yield (
+                 gr.Button(
+                     value=spinner + " Generating Responses " + spinner,
+                     interactive=False,
+                     variant="primary",
+                 ),
+                 resps[0],
+                 resps[1],
+                 tts_resps[0],
+                 tts_resps[1],
+                 gr.Button(visible=False),
+                 gr.Button(visible=False),
+                 gr.Button(visible=False),
+                 state,
+                 audio_input,
+                 None,
+                 None,
+                 latencies,
+             )
+
+         # Resample the prompt audio, cache it to disk, and synthesize speech
+         # for this model's text response via the OpenAI TTS endpoint.
+         sr, y = audio_input
+         x = xxhash.xxh32(bytes(y)).hexdigest()
+         y = y.astype(np.float32)
+         y /= np.max(np.abs(y))
+         a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
+         sf.write(f"{x}_resp{order}.wav", a["array"], a["sampling_rate"], format="wav")
+         tts_options = {
+             "model": "gpt-4o-mini-tts",
+             "voice": "alloy",
+             "input": resps[order].__dict__["_constructor_args"][0]["value"],
+             "response_format": "wav",
+         }
+         abytes = OpenAI(api_key=os.environ["OPENAI_API_KEY"]).audio.speech.create(**tts_options).content
+         tts_resps[order] = gr.Audio(
+             value=abytes,
+             visible=True,
+         )
+         latencies[order]["total_time"] = time.time() - start_time
+         latencies[order]["response_length"] = total_length
+     print(latencies)
+     yield (
+         gr.Button(value="Vote for which model is better!", interactive=False, variant="primary", visible=False),
+         resps[0],
+         resps[1],
+         tts_resps[0],
+         tts_resps[1],
+         gr.Button(visible=not error_in_model),
+         gr.Button(visible=not error_in_model),
+         gr.Button(visible=not error_in_model),
+         responses_complete(state),
+         audio_input,
+         gr.Textbox(visible=False),
+         gr.Audio(visible=False),
+         latencies,
+     )
+
+
+ def on_page_load(state, model_order):
+     if state == 0:
+         # gr.Info(
+         #     "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant,"
+         #     " or ChatGPT for."
+         # )
+         state = 1
+         model_order = random.sample(all_models, 2) if anonymous else model_order
+     return state, model_order
+
+
+ def recording_complete(state):
+     if state == 1:
+         # gr.Info(
+         #     "Once you submit your recording, you'll receive responses from different models. This might take a second."
+         # )
+         state = 2
+     return (
+         gr.Button(value="Starting Generation", interactive=False, variant="primary"),
+         state,
+     )
+
+
+ def responses_complete(state):
+     if state == 2:
+         gr.Info(
+             "Give us your feedback! Mark which model gave you the best response so we can understand the quality of"
+             " these different voice assistant models."
+         )
+         state = 3
+     return state
+
+
+ def clear_factory(button_id):
+     async def clear(audio_input, model_order, pref_counter, reasoning, latency):
+         textbox1 = gr.Textbox(visible=False)
+         textbox2 = gr.Textbox(visible=False)
+         if button_id is not None:
+             sr, y = audio_input
+             x = xxhash.xxh32(bytes(y)).hexdigest()
+             await db.insert(
+                 {
+                     "audio_hash": x,
+                     "outcome": button_id,
+                     "model_a": model_shorthand[model_order[0]],
+                     "model_b": model_shorthand[model_order[1]],
+                     "why": reasoning,
+                     "model_a_latency": latency[0],
+                     "model_b_latency": latency[1],
+                 }
+             )
+             pref_counter += 1
+             model_a = model_name[model_order[0]]
+             model_b = model_name[model_order[1]]
+
+         counter_text = f"# {pref_counter}/10 Preferences Submitted"
+         if pref_counter >= 10:
+             code = "C1ARB3D6"
+             counter_text = f"# Completed! Completion Code: {code}"
+         if anonymous:
+             model_order = random.sample(all_models, 2)
+         return (
+             model_order,
+             gr.Button(
+                 value="Record Audio to Submit Again!",
+                 interactive=False,
+                 visible=True,
+             ),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             None,
+             textbox1,
+             textbox2,
+             gr.Audio(visible=False),
+             gr.Audio(visible=False),
+             pref_counter,
+             counter_text,
+             gr.Textbox(visible=False),
+             gr.Audio(visible=False),
+         )
+
+     return clear
+
+
+ def transcribe(transc, voice_reason):
+     # Relies on asr_pipe, which is disabled above; the call site below is commented out accordingly.
+     if transc is None:
+         transc = ""
+     transc += " " + asr_pipe(voice_reason, generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"]
+     return transc, gr.Audio(value=None)
+
+
+ theme = gr.themes.Soft(
+     primary_hue=gr.themes.Color(
+         c100="#82000019",
+         c200="#82000033",
+         c300="#8200004c",
+         c400="#82000066",
+         c50="#8200007f",
+         c500="#8200007f",
+         c600="#82000099",
+         c700="#820000b2",
+         c800="#820000cc",
+         c900="#820000e5",
+         c950="#820000f2",
+     ),
+     secondary_hue="rose",
+     neutral_hue="stone",
+ )
+
+ with open("../src/talk_arena/styles.css", "r") as css_file:
+     custom_css = css_file.read()
+
+ db = TinyThreadSafeDB("audio_out_votes.json")
+
+ with gr.Blocks(theme=theme, fill_height=True, css=custom_css) as demo:
+     submitted_preferences = gr.State(0)
+     state = gr.State(0)
+     model_order = gr.State([])
+     latency = gr.State([])
+     with gr.Row():
+         counter_text = gr.Markdown(
+             "# 0/10 Preferences Submitted.\n Follow the pop-up tips to submit your first preference."
+         )
+         category_description_text = gr.Markdown("PLACEHOLDER FOR ALI TO FILL IN LATER")
+     with gr.Row():
+         audio_input = gr.Audio(sources=["microphone"], streaming=False, label="Audio Input")
+
+     with gr.Row(equal_height=True):
+         with gr.Column(scale=1):
+             out1 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+             audio_out1 = gr.Audio(visible=False)
+         with gr.Column(scale=1):
+             out2 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+             audio_out2 = gr.Audio(visible=False)
+
+     with gr.Row():
+         btn = gr.Button(value="Record Audio to Submit!", interactive=False)
+
+     with gr.Row(equal_height=True):
+         reason = gr.Textbox(label="[Optional] Explain Your Preferences", visible=False, scale=4)
+         reason_record = gr.Audio(
+             sources=["microphone"],
+             interactive=True,
+             streaming=False,
+             label="Speak to transcribe!",
+             visible=False,
+             type="filepath",
+             # waveform_options={"show_recording_waveform": False},
+             scale=1,
+         )
+
+     with gr.Row():
+         best1 = gr.Button(value="Model 1 is better", visible=False)
+         tie = gr.Button(value="Tie", visible=False)
+         best2 = gr.Button(value="Model 2 is better", visible=False)
+
+     with gr.Row():
+         contact = gr.Markdown("")
+
+     # reason_record.stop_recording(transcribe, inputs=[reason, reason_record], outputs=[reason, reason_record])
+     audio_input.stop_recording(
+         recording_complete,
+         [state],
+         [btn, state],
+     ).then(
+         fn=pairwise_response_async,
+         inputs=[audio_input, state, model_order],
+         outputs=[
+             btn,
+             out1,
+             out2,
+             audio_out1,
+             audio_out2,
+             best1,
+             best2,
+             tie,
+             state,
+             audio_input,
+             reason,
+             reason_record,
+             latency,
+         ],
+     )
+     audio_input.start_recording(
+         lambda: gr.Button(value="Uploading Audio to Cloud", interactive=False, variant="primary"),
+         None,
+         btn,
+     )
+     best1.click(
+         fn=clear_factory(0),
+         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+         outputs=[
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             audio_out1,
+             audio_out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     tie.click(
+         fn=clear_factory(0.5),
+         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+         outputs=[
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             audio_out1,
+             audio_out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     best2.click(
+         fn=clear_factory(1),
+         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+         outputs=[
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             audio_out1,
+             audio_out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     audio_input.clear(
+         clear_factory(None),
+         [audio_input, model_order, submitted_preferences, reason, latency],
+         [
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             audio_out1,
+             audio_out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     demo.load(fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order])
+
+ if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)
talk_arena/db_utils.py ADDED
@@ -0,0 +1,37 @@
+ import uuid
+ from asyncio import Lock as ALock
+ from contextlib import asynccontextmanager
+ from threading import Lock as TLock
+
+ from tinydb import TinyDB
+ from tinydb.table import Table as TinyDBTable
+
+
+ class UUIDTable(TinyDBTable):
+     # Use random UUIDs instead of TinyDB's auto-incrementing integer IDs.
+     document_id_class = uuid.UUID
+
+     def _get_next_id(self):
+         return uuid.uuid4()
+
+
+ class UUIDB(TinyDB):
+     table_class = UUIDTable
+
+
+ class TinyThreadSafeDB:
+     def __init__(self, db_path: str):
+         self.db = UUIDB(db_path)
+         self._lock1 = TLock()  # serializes access across threads
+         self._lock2 = ALock()  # serializes access across coroutines
+
+     @asynccontextmanager
+     async def atomic_operation(self):
+         """Context manager for thread-safe database operations"""
+         with self._lock1:
+             async with self._lock2:
+                 yield self.db
+
+     async def insert(self, data: dict):
+         """Thread-safe insertion of preference data"""
+         async with self.atomic_operation() as db:
+             db.insert(data)
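A minimal usage sketch for TinyThreadSafeDB; the file name and payload below are illustrative, but audio_collection.py and demo.py call insert() the same way:

# hypothetical usage — not part of the commit
import asyncio

from talk_arena.db_utils import TinyThreadSafeDB

db = TinyThreadSafeDB("example_votes.json")


async def main():
    # insert() acquires both the threading lock and the asyncio lock, so
    # concurrent coroutines and Gradio worker threads serialize their writes.
    await db.insert({"audio_hash": "deadbeef", "outcome": 0.5})


asyncio.run(main())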
talk_arena/demo.py ADDED
@@ -0,0 +1,432 @@
+ import argparse
+ import asyncio
+ import os
+ import random
+ import textwrap
+ import time
+
+ import gradio as gr
+ import xxhash
+ from dotenv import load_dotenv
+ from transformers import pipeline
+
+ import talk_arena.streaming_helpers as sh
+ from talk_arena.db_utils import TinyThreadSafeDB
+
+
+ load_dotenv()
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Talk Arena Demo")
+     parser.add_argument("--free_only", action="store_true", help="Only use free models")
+     return parser.parse_args()
+
+
+ args = parse_args()
+
+ if gr.NO_RELOAD:  # Prevents re-init during hot reloading
+     # Transcription disabled for the public interface
+     # asr_pipe = pipeline(
+     #     task="automatic-speech-recognition",
+     #     model="openai/whisper-large-v3-turbo",
+     #     chunk_length_s=30,
+     #     device="cuda:1",
+     # )
+
+     anonymous = True
+
+     # Generation setup
+     diva_audio, diva = sh.api_streaming("WillHeld/DiVA-llama-3-v0-8b")
+     qwen2_audio, qwen2 = sh.api_streaming("Qwen/Qwen2-Audio-7B-Instruct")
+     pipelined_system, pipeline_model = sh.api_streaming("pipeline/meta-llama/Meta-Llama-3-8B-Instruct")
+     if not args.free_only:
+         gemini_audio, gemini_model = sh.gemini_streaming("models/gemini-1.5-flash")
+         gpt4o_audio, gpt4o_model = sh.gpt4o_streaming("models/gpt4o")
+         geminip_audio, geminip_model = sh.gemini_streaming("models/gemini-1.5-pro")
+         gemini2_audio, gemini2_model = sh.gemini_streaming("models/gemini-2.0-flash-exp")
+     typhoon_audio, typhoon_model = sh.api_streaming("scb10x/llama-3-typhoon-audio-8b-2411")
+
+     competitor_info = [
+         (sh.gradio_gen_factory(diva_audio, "DiVA Llama 3 8B", anonymous), "diva_3_8b", "DiVA Llama 3 8B"),
+         (sh.gradio_gen_factory(qwen2_audio, "Qwen 2", anonymous), "qwen2", "Qwen 2 Audio"),
+         (
+             sh.gradio_gen_factory(pipelined_system, "Pipelined Llama 3 8B", anonymous),
+             "pipe_l3.0",
+             "Pipelined Llama 3 8B",
+         ),
+         (sh.gradio_gen_factory(typhoon_audio, "Typhoon Audio", anonymous), "typhoon_audio", "Typhoon Audio"),
+     ]
+     # Add paid models if flag is not set
+     if not args.free_only:
+         competitor_info += [
+             (sh.gradio_gen_factory(gemini_audio, "Gemini 1.5 Flash", anonymous), "gemini_1.5f", "Gemini 1.5 Flash"),
+             (sh.gradio_gen_factory(gpt4o_audio, "GPT4o", anonymous), "gpt4o", "GPT-4o"),
+             (sh.gradio_gen_factory(geminip_audio, "Gemini 1.5 Pro", anonymous), "gemini_1.5p", "Gemini 1.5 Pro"),
+             (sh.gradio_gen_factory(gemini2_audio, "Gemini 2 Flash", anonymous), "gemini_2f", "Gemini 2 Flash"),
+         ]
+
+     resp_generators = [generator for generator, _, _ in competitor_info]
+     model_shorthand = [shorthand for _, shorthand, _ in competitor_info]
+     model_name = [full_name for _, _, full_name in competitor_info]
+     all_models = list(range(len(model_shorthand)))
+
+
+ async def pairwise_response_async(audio_input, state, model_order):
+     if audio_input is None:
+         raise StopAsyncIteration(
+             "",
+             "",
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             state,
+             audio_input,
+             None,
+             None,
+             None,
+         )
+     spinner_id = 0
+     spinners = ["◐ ", "◓ ", "◑", "◒"]
+     spinner = spinners[0]
+     gen_pair = [resp_generators[model_order[0]], resp_generators[model_order[1]]]
+     latencies = [{}, {}]  # Store timing info for each model
+     resps = [gr.Textbox(value="", info="", visible=False), gr.Textbox(value="", info="", visible=False)]
+
+     error_in_model = False
+     for order, generator in enumerate(gen_pair):
+         start_time = time.time()
+         first_token = True
+         total_length = 0
+         try:
+             async for local_resp in generator(audio_input, order):
+                 total_length += 1
+                 if first_token:
+                     latencies[order]["time_to_first_token"] = time.time() - start_time
+                     first_token = False
+                 resps[order] = local_resp
+                 spinner = spinners[spinner_id]
+                 spinner_id = (spinner_id + 1) % 4
+                 yield (
+                     gr.Button(
+                         value=spinner + " Generating Responses " + spinner,
+                         interactive=False,
+                         variant="primary",
+                     ),
+                     resps[0],
+                     resps[1],
+                     gr.Button(visible=False),
+                     gr.Button(visible=False),
+                     gr.Button(visible=False),
+                     state,
+                     audio_input,
+                     None,
+                     None,
+                     latencies,
+                 )
+         except Exception:
+             error_in_model = True
+             resps[order] = gr.Textbox(
+                 info=f"<strong>Error thrown by Model {order+1} API</strong>",
+                 value="" if first_token else resps[order]._constructor_args[0]["value"],
+                 visible=True,
+                 label=f"Model {order+1}",
+             )
+             yield (
+                 gr.Button(
+                     value=spinner + " Generating Responses " + spinner,
+                     interactive=False,
+                     variant="primary",
+                 ),
+                 resps[0],
+                 resps[1],
+                 gr.Button(visible=False),
+                 gr.Button(visible=False),
+                 gr.Button(visible=False),
+                 state,
+                 audio_input,
+                 None,
+                 None,
+                 latencies,
+             )
+         latencies[order]["total_time"] = time.time() - start_time
+         latencies[order]["response_length"] = total_length
+     print(latencies)
+     yield (
+         gr.Button(value="Vote for which model is better!", interactive=False, variant="primary", visible=False),
+         resps[0],
+         resps[1],
+         gr.Button(visible=not error_in_model),
+         gr.Button(visible=not error_in_model),
+         gr.Button(visible=not error_in_model),
+         responses_complete(state),
+         audio_input,
+         gr.Textbox(visible=False),
+         gr.Audio(visible=False),
+         latencies,
+     )
+
+
+ def on_page_load(state, model_order):
+     if state == 0:
+         # gr.Info(
+         #     "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant,"
+         #     " or ChatGPT for."
+         # )
+         state = 1
+         model_order = random.sample(all_models, 2) if anonymous else model_order
+     return state, model_order
+
+
+ def recording_complete(state):
+     if state == 1:
+         # gr.Info(
+         #     "Once you submit your recording, you'll receive responses from different models. This might take a second."
+         # )
+         state = 2
+     return (
+         gr.Button(value="Starting Generation", interactive=False, variant="primary"),
+         state,
+     )
+
+
+ def responses_complete(state):
+     if state == 2:
+         gr.Info(
+             "Give us your feedback! Mark which model gave you the best response so we can understand the quality of"
+             " these different voice assistant models."
+         )
+         state = 3
+     return state
+
+
+ def clear_factory(button_id):
+     async def clear(audio_input, model_order, pref_counter, reasoning, latency):
+         textbox1 = gr.Textbox(visible=False)
+         textbox2 = gr.Textbox(visible=False)
+         if button_id is not None:
+             sr, y = audio_input
+             x = xxhash.xxh32(bytes(y)).hexdigest()
+             await db.insert(
+                 {
+                     "audio_hash": x,
+                     "outcome": button_id,
+                     "model_a": model_shorthand[model_order[0]],
+                     "model_b": model_shorthand[model_order[1]],
+                     "why": reasoning,
+                     "model_a_latency": latency[0],
+                     "model_b_latency": latency[1],
+                 }
+             )
+             pref_counter += 1
+             model_a = model_name[model_order[0]]
+             model_b = model_name[model_order[1]]
+             textbox1 = gr.Textbox(
+                 visible=True,
+                 info=f"<strong style='color: #53565A'>Response from {model_a}</strong><p>Time-to-First-Character: {latency[0]['time_to_first_token']:.2f} s, Time Per Character: {latency[0]['total_time']/latency[0]['response_length']:.2f} s</p>",
+             )
+             textbox2 = gr.Textbox(
+                 visible=True,
+                 info=f"<strong style='color: #53565A'>Response from {model_b}</strong><p>Time-to-First-Character: {latency[1]['time_to_first_token']:.2f} s, Time Per Character: {latency[1]['total_time']/latency[1]['response_length']:.2f} s</p>",
+             )
+
+         try:
+             sr, y = audio_input
+             x = xxhash.xxh32(bytes(y)).hexdigest()
+             os.remove(f"{x}.wav")
+         except Exception:
+             # File already deleted; this is just a failsafe to ensure data is cleared.
+             pass
+         counter_text = f"# {pref_counter}/10 Preferences Submitted"
+         if pref_counter >= 10 and False:  # Currently disabled; manages Prolific completion
+             code = "PLACEHOLDER"
+             counter_text = f"# Completed! Completion Code: {code}"
+         counter_text = ""
+         if anonymous:
+             model_order = random.sample(all_models, 2)
+         return (
+             model_order,
+             gr.Button(
+                 value="Record Audio to Submit Again!",
+                 interactive=False,
+                 visible=True,
+             ),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             gr.Button(visible=False),
+             None,
+             textbox1,
+             textbox2,
+             pref_counter,
+             counter_text,
+             gr.Textbox(visible=False),
+             gr.Audio(visible=False),
+         )
+
+     return clear
+
+
+ def transcribe(transc, voice_reason):
+     if transc is None:
+         transc = ""
+     transc += " " + asr_pipe(voice_reason, generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"]
+     return transc, gr.Audio(value=None)
+
+
+ theme = gr.themes.Soft(
+     primary_hue=gr.themes.Color(
+         c100="#82000019",
+         c200="#82000033",
+         c300="#8200004c",
+         c400="#82000066",
+         c50="#8200007f",
+         c500="#8200007f",
+         c600="#82000099",
+         c700="#820000b2",
+         c800="#820000cc",
+         c900="#820000e5",
+         c950="#820000f2",
+     ),
+     secondary_hue="rose",
+     neutral_hue="stone",
+ )
+
+ with open("src/talk_arena/styles.css", "r") as css_file:
+     custom_css = css_file.read()
+
+ db = TinyThreadSafeDB("live_votes.json")
+
+ with gr.Blocks(theme=theme, fill_height=True, css=custom_css) as demo:
+     submitted_preferences = gr.State(0)
+     state = gr.State(0)
+     model_order = gr.State([])
+     latency = gr.State([])
+     with gr.Row():
+         counter_text = gr.Markdown(
+             ""
+         )  # "# 0/10 Preferences Submitted.\n Follow the pop-up tips to submit your first preference."
+     with gr.Row():
+         audio_input = gr.Audio(sources=["microphone"], streaming=False, label="Audio Input")
+
+     with gr.Row(equal_height=True):
+         with gr.Column(scale=1):
+             out1 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+         with gr.Column(scale=1):
+             out2 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+
+     with gr.Row():
+         btn = gr.Button(value="Record Audio to Submit!", interactive=False)
+
+     with gr.Row(equal_height=True):
+         reason = gr.Textbox(label="[Optional] Explain Your Preferences", visible=False, scale=4)
+         reason_record = gr.Audio(
+             sources=["microphone"],
+             interactive=True,
+             streaming=False,
+             label="Speak to transcribe!",
+             visible=False,
+             type="filepath",
+             # waveform_options={"show_recording_waveform": False},
+             scale=1,
+         )
+
+     with gr.Row():
+         best1 = gr.Button(value="Model 1 is better", visible=False)
+         tie = gr.Button(value="Tie", visible=False)
+         best2 = gr.Button(value="Model 2 is better", visible=False)
+
+     with gr.Row():
+         contact = gr.Markdown("")
+
+     # reason_record.stop_recording(transcribe, inputs=[reason, reason_record], outputs=[reason, reason_record])
+     audio_input.stop_recording(
+         recording_complete,
+         [state],
+         [btn, state],
+     ).then(
+         fn=pairwise_response_async,
+         inputs=[audio_input, state, model_order],
+         outputs=[btn, out1, out2, best1, best2, tie, state, audio_input, reason, reason_record, latency],
+     )
+     audio_input.start_recording(
+         lambda: gr.Button(value="Uploading Audio to Cloud", interactive=False, variant="primary"),
+         None,
+         btn,
+     )
+     best1.click(
+         fn=clear_factory(0),
+         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+         outputs=[
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     tie.click(
+         fn=clear_factory(0.5),
+         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+         outputs=[
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     best2.click(
+         fn=clear_factory(1),
+         inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+         outputs=[
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     audio_input.clear(
+         clear_factory(None),
+         [audio_input, model_order, submitted_preferences, reason, latency],
+         [
+             model_order,
+             btn,
+             best1,
+             best2,
+             tie,
+             audio_input,
+             out1,
+             out2,
+             submitted_preferences,
+             counter_text,
+             reason,
+             reason_record,
+         ],
+     )
+     demo.load(fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order])
+
+ if __name__ == "__main__":
+     demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)
talk_arena/leaderboard_viz.py ADDED
@@ -0,0 +1,463 @@
+ import json
+ import random
+ import textwrap
+ from collections import defaultdict
+ from datetime import datetime
+ from typing import Dict, List, Tuple
+ from zoneinfo import ZoneInfo
+
+ import gradio as gr
+ import numpy as np
+ import pandas as pd
+ import plotly.express as px
+ import plotly.io as pio
+ from apscheduler.schedulers.background import BackgroundScheduler
+ from scipy.optimize import minimize
+ from scipy.special import expit
+
+
+ # Constants
+ COLORS = [
+     "#1B7FFF",
+     "#F07D1A",
+     "#BA24C7",
+     "#FE42C7",
+     "#0D4B7C",
+     "#0EAC96",
+     "#AA7CFF",
+     "#B50550",
+     "#009EEB",
+     "#220B55",
+     "#7B3301",
+ ]
+ WR_PLOT = None
+ BT_PLOT = None
+ UPDATE_TIME = None
+ NAME_MAPPING = {
+     "gemini_2f": "Gemini 2.0 Flash (Experimental)",
+     "diva_3_8b": "DiVA Llama 3 8B",
+     "qwen2": "Qwen 2 Audio",
+     "pipe_l3.0": "Pipelined Llama 3 8B",
+     "gemini_1.5f": "Gemini 1.5 Flash",
+     "gpt4o": "GPT-4o",
+     "gemini_1.5p": "Gemini 1.5 Pro",
+     "typhoon_audio": "Typhoon Audio",
+ }
+
+
+ def get_aesthetic_timestamp():
+     """
+     Returns a beautifully formatted timestamp in the format:
+     'Tuesday, December 10th, 2024 at 3:45 PM'
+     """
+     # Get timezone object for PST
+     pst = ZoneInfo("America/Los_Angeles")
+
+     # Get current time in PST
+     now = datetime.now(pst)
+
+     # Add suffix to day number (1st, 2nd, 3rd, etc.)
+     day = now.day
+     if 4 <= day <= 20 or 24 <= day <= 30:
+         suffix = "th"
+     else:
+         suffix = ["st", "nd", "rd"][day % 10 - 1]
+     return now.strftime(f"%A, %B {day}{suffix}, %Y at %-I:%M %p")
+
+
+ def bootstrap_ci(data, n_bootstrap=10000, ci=95):
+     """Calculate bootstrap confidence intervals."""
+     bootstrap_samples = []
+     for _ in range(n_bootstrap):
+         bootstrap_samples.append(np.mean(random.choices(data, k=len(data))))
+     lower_bound = np.percentile(bootstrap_samples, (100 - ci) / 2)
+     upper_bound = np.percentile(bootstrap_samples, 100 - (100 - ci) / 2)
+     return lower_bound, upper_bound
+
+
+ def calculate_win_rates(json_data):
+     """Calculate win rates from JSON data."""
+     data = json.loads(json_data)
+
+     model_wins = defaultdict(int)
+     total_matches = defaultdict(int)
+     total_votes = 0
+
+     for value in data["_default"].values():
+         total_votes += 1
+         if value["outcome"] == 0:
+             model_wins[value["model_a"]] += 1
+         elif value["outcome"] == 1:
+             model_wins[value["model_b"]] += 1
+         elif value["outcome"] == 0.5:
+             model_wins[value["model_a"]] += 0.5
+             model_wins[value["model_b"]] += 0.5
+         total_matches[value["model_a"]] += 1
+         total_matches[value["model_b"]] += 1
+
+     per_model_wins = {}
+     for model, wins in model_wins.items():
+         win_rate = wins / total_matches[model]
+         wins_data = [1] * int(wins) + [0] * int(total_matches[model] - wins)
+         if int(wins) != wins:
+             wins_data += [0.5]
+         lower, upper = bootstrap_ci(wins_data)
+         per_model_wins[model] = {
+             "model": model,
+             "win_rate": win_rate,
+             "95_lower": (win_rate - lower),
+             "95_upper": (upper - win_rate),
+         }
+     df = pd.DataFrame.from_dict(per_model_wins).T
+
+     return df, total_votes
+
+
+ def create_win_rate_plot(wins_df):
+     """Create win rate plot using Plotly."""
+     wins_df["Source"] = wins_df["Source"].astype(str)
+     wins_df = wins_df.sort_values(by=["Source", "win_rate"], ascending=False)
+     wins_df["model"] = wins_df["model"].apply(lambda x: NAME_MAPPING.get(x, x))
+
+     fig = px.bar(
+         wins_df,
+         x="model",
+         y="win_rate",
+         error_y="95_upper",
+         error_y_minus="95_lower",
+         color="model",
+         color_discrete_sequence=COLORS,
+         animation_group="model",
+         animation_frame="Source",
+     )
+
+     fig.update_traces(
+         hovertemplate="<b>%{x}</b><br>" + "Win Rate: %{y}" + "<extra></extra>",
+     )
+
+     fig.update_layout(
+         autosize=True,
+         showlegend=False,
+         plot_bgcolor="white",
+         title={
+             "text": "Talk Arena Live Win Rates<br>with 95% Confidence Intervals",
+             "y": 0.95,
+             "x": 0.5,
+             "xanchor": "center",
+             "yanchor": "top",
+         },
+         xaxis_title="Model",
+         yaxis_title="Win Rate (%)",
+         bargap=0.2,
+         yaxis=dict(
+             tickformat=".0%", tickmode="auto", range=[0, 1.01], gridcolor="#C9CCD1", griddash="dash", gridwidth=2
+         ),
+         legend=dict(
+             orientation="h",  # Make legend horizontal
+             yanchor="bottom",
+             y=-0.5,  # Position below plot
+             xanchor="center",
+             x=0.5,  # Center horizontally
+             bgcolor="rgba(255, 255, 255, 0.8)",
+             bordercolor="#C9CCD1",
+             borderwidth=1,
+         ),
+         margin=dict(l=10, r=10, t=0, b=10),  # Balanced margins
+         hoverlabel=dict(bgcolor="white", font_size=14, bordercolor="gray"),
+     )
+
+     fig.update_xaxes(showgrid=False)
+
+     return fig
+
+
+ # Bradley-Terry Model Functions
+ def load_live_votes(json_str: str) -> pd.DataFrame:
+     """Load and preprocess live votes data from JSON string."""
+     data = json.loads(json_str)
+     df = pd.DataFrame.from_dict(data["_default"], orient="index")
+     df["winner"] = df["outcome"].map({1: "model_b", 0: "model_a", 0.5: "tie"})
+     return df
+
+
+ def preprocess_for_bt(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, List[str], np.ndarray]:
+     """Preprocess data for Bradley-Terry model fitting."""
+     all_models = pd.concat([df["model_a"], df["model_b"]]).unique()
+     model_to_idx = {model: idx for idx, model in enumerate(all_models)}
+
+     matchups = np.array([[model_to_idx[row.model_a], model_to_idx[row.model_b]] for _, row in df.iterrows()])
+
+     outcomes = np.array(
+         [1.0 if row.winner == "model_a" else (0.5 if row.winner == "tie" else 0.0) for _, row in df.iterrows()]
+     )
+
+     unique_matches = np.column_stack([matchups, outcomes])
+     unique_matches, weights = np.unique(unique_matches, return_counts=True, axis=0)
+
+     return (unique_matches[:, :2].astype(np.int32), unique_matches[:, 2], list(all_models), weights.astype(np.float64))
+
+
+ def bt_loss_and_grad(
+     ratings: np.ndarray, matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, alpha: float = 1.0
+ ) -> Tuple[float, np.ndarray]:
+     """Compute Bradley-Terry loss and gradient."""
+     matchup_ratings = ratings[matchups]
+     logits = alpha * (matchup_ratings[:, 0] - matchup_ratings[:, 1])
+     probs = expit(logits)
+
+     loss = -((np.log(probs) * outcomes + np.log(1.0 - probs) * (1.0 - outcomes)) * weights).sum()
+
+     matchups_grads = -alpha * (outcomes - probs) * weights
+     model_grad = np.zeros_like(ratings)
+     np.add.at(model_grad, matchups[:, [0, 1]], matchups_grads[:, None] * np.array([1.0, -1.0], dtype=np.float64))
+
+     return loss, model_grad
+
+
+ def fit_bt(
+     matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, n_models: int, alpha: float, tol: float = 1e-6
+ ) -> np.ndarray:
+     """Fit Bradley-Terry model using L-BFGS-B optimization."""
+     initial_ratings = np.zeros(n_models, dtype=np.float64)
+
+     result = minimize(
+         fun=bt_loss_and_grad,
+         x0=initial_ratings,
+         args=(matchups, outcomes, weights, alpha),
+         jac=True,
+         method="L-BFGS-B",
+         options={"disp": False, "maxiter": 100, "gtol": tol},
+     )
+
+     return result["x"]
+
+
+ def scale_and_offset(
+     ratings: np.ndarray, models: List[str], scale: float = 400, init_rating: float = 1000
+ ) -> np.ndarray:
+     """Scale ratings to familiar Elo-like scale."""
+     scaled_ratings = (ratings * scale) + init_rating
+     return scaled_ratings
+
+
+ def compute_bootstrap_bt(
+     data: str,
+     num_round: int = 100,
+     base: float = 10.0,
+     scale: float = 400.0,
+     init_rating: float = 1000.0,
+     tol: float = 1e-6,
+ ) -> pd.DataFrame:
+     """Compute bootstrap Bradley-Terry ratings from live votes data."""
+     df = load_live_votes(data)
+     matchups, outcomes, models, weights = preprocess_for_bt(df)
+
+     rng = np.random.default_rng(seed=0)
+     total_matches = len(df)
+     idxs = rng.multinomial(n=total_matches, pvals=weights / weights.sum(), size=num_round)
+     boot_weights = idxs.astype(np.float64) / total_matches
+
+     ratings_list = []
+     for sample_weights in boot_weights:
+         ratings = fit_bt(
+             matchups=matchups,
+             outcomes=outcomes,
+             weights=sample_weights,
+             n_models=len(models),
+             alpha=np.log(base),
+             tol=tol,
+         )
+         scaled_ratings = scale_and_offset(ratings=ratings, models=models, scale=scale, init_rating=init_rating)
+         ratings_list.append(scaled_ratings)
+
+     df_ratings = pd.DataFrame(ratings_list, columns=models)
+     return df_ratings[df_ratings.median().sort_values(ascending=False).index]
+
+
+ def create_bt_plot(bootstrap_ratings):
+     """Create Bradley-Terry ratings plot using Plotly."""
+     melted_bootstrap = bootstrap_ratings.melt(id_vars=["Source", "level_1"], var_name="Model", value_name="BT")
+     melted_bootstrap = melted_bootstrap.dropna()
+     melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "Model", "BT"], ascending=False)
+     # Pretty Names
+     melted_bootstrap["Model"] = melted_bootstrap["Model"].apply(lambda x: NAME_MAPPING.get(x, x))
+     # Compression for Client Side
+     melted_bootstrap["BT"] = melted_bootstrap["BT"].apply(lambda x: int(x))
+     min_samp = melted_bootstrap[melted_bootstrap["BT"] > 0]["BT"].min()
+     max_samp = melted_bootstrap["BT"].max()
+     idx_keep = list(range(0, len(melted_bootstrap), 10))
+     melted_bootstrap = melted_bootstrap.iloc[idx_keep]
+     melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "BT"], ascending=False)
+     fig = px.violin(
+         melted_bootstrap,
+         x="Model",
+         y="BT",
+         color="Model",
+         animation_group="Model",
+         animation_frame="Source",
+         color_discrete_sequence=COLORS,
+     )
+
+     fig.update_layout(
+         autosize=True,
+         showlegend=False,
+         plot_bgcolor="white",
+         title={
+             "text": "Talk Arena Live Bradley-Terry Ratings<br>with Bootstrapped Variance",
+             "y": 0.9,
+             "x": 0.5,
+             "xanchor": "center",
+             "yanchor": "top",
+         },
+         xaxis_title="Model",
+         yaxis_title="Rating",
+         yaxis=dict(gridcolor="#C9CCD1", range=[min_samp - 10, max_samp + 10], griddash="dash"),
+         legend=dict(
+             orientation="h",  # Make legend horizontal
+             yanchor="bottom",
+             y=-0.5,  # Position below plot
+             xanchor="center",
+             x=0.5,  # Center horizontally
+             bgcolor="rgba(255, 255, 255, 0.8)",
+             bordercolor="#C9CCD1",
+             borderwidth=1,
+         ),
+         margin=dict(l=10, r=10, t=0, b=10),  # Balanced margins
+     )
+
+     fig.update_xaxes(showgrid=False)
+     fig.update_yaxes(showgrid=True, gridwidth=2)
+
+     return fig
+
+
+ def get_wr_plot():
+     jrep = json.loads(pio.to_json(WR_PLOT))
+     for step in jrep["layout"]["sliders"][0]["steps"]:
+         step["args"][1]["frame"]["duration"] = 500
+         step["args"][1]["transition"]["duration"] = 500
+     jrep["layout"]["updatemenus"] = []
+     jrep["layout"]["sliders"][0]["len"] = 0.8
+     jrep["layout"]["sliders"][0]["pad"] = {}
+     return json.dumps(jrep)
+
+
+ def get_bt_plot():
+     jrep = json.loads(pio.to_json(BT_PLOT))
+     for step in jrep["layout"]["sliders"][0]["steps"]:
+         step["args"][1]["frame"]["duration"] = 500
+         step["args"][1]["transition"]["duration"] = 500
+     jrep["layout"]["updatemenus"] = []
+     jrep["layout"]["sliders"][0]["len"] = 0.8
+     jrep["layout"]["sliders"][0]["pad"] = {}
+     return json.dumps(jrep)
+
+
+ def get_update_time():
+     return UPDATE_TIME
+
+
+ def viz_factory(force=False):
+     def process_and_visualize():
+         """Main function to process JSON data and create visualizations."""
+         global WR_PLOT, BT_PLOT, UPDATE_TIME
+         if WR_PLOT is not None and BT_PLOT is not None and not force:
+             return WR_PLOT, BT_PLOT, UPDATE_TIME
+         try:
+             # Read JSON data
+             pub_json_data = open("/home/wheld3/talk-arena/live_votes.json", "r").read()
+             prolific_json_data = open("/home/wheld3/talk-arena/prolific_votes.json", "r").read()
+             merged_json_data = json.dumps(
+                 {"_default": {**json.loads(pub_json_data)["_default"], **json.loads(prolific_json_data)["_default"]}}
+             )
+             # Calculate win rates and create win rate plot
+             pub_win_rates, pub_votes = calculate_win_rates(pub_json_data)
+             pro_win_rates, pro_votes = calculate_win_rates(prolific_json_data)
+             total_win_rates, total_votes = calculate_win_rates(merged_json_data)
+             all_models = total_win_rates["model"].unique()
+             pro_models = pro_win_rates["model"].unique()
+             for model in all_models:
+                 if model not in pro_models:
+                     new_index = len(pro_win_rates)
+                     pro_win_rates.loc[new_index] = [model, -0.1, -0.1, -0.2]
+
+             win_rates = (
+                 pd.concat([pub_win_rates, pro_win_rates, total_win_rates], keys=["Public", "Prolific", "Total"])
+                 .reset_index()
+                 .rename(columns={"level_0": "Source"})
+             )
+             WR_PLOT = create_win_rate_plot(win_rates)
+
+             # Calculate Bradley-Terry ratings and create BT plot
+             pub_bootstrap_ratings = compute_bootstrap_bt(pub_json_data, num_round=10000)
+             pro_bootstrap_ratings = compute_bootstrap_bt(prolific_json_data, num_round=10000)
+             total_bootstrap_ratings = compute_bootstrap_bt(merged_json_data, num_round=10000)
+             for model in all_models:
+                 if model not in pro_models:
+                     pro_bootstrap_ratings[model] = pro_bootstrap_ratings["diva_3_8b"] * -1
+
+             bootstrap_ratings = (
+                 pd.concat(
+                     [pub_bootstrap_ratings, pro_bootstrap_ratings, total_bootstrap_ratings],
+                     keys=["Public", "Prolific", "Total"],
+                 )
+                 .reset_index()
+                 .rename(columns={"level_0": "Source"})
+             )
+             BT_PLOT = create_bt_plot(bootstrap_ratings)
+             UPDATE_TIME = gr.Markdown(
+                 value=textwrap.dedent(
+                     f"""
+                     <h4 class="nx-font-semibold nx-tracking-tight nx-text-slate-900 dark:nx-text-slate-100 nx-text-xl">Last Refresh: {get_aesthetic_timestamp()} PST</h4>
+                     <h6 class="nx-font-semibold nx-tracking-tight nx-text-slate-900 dark:nx-text-slate-100 nx-text-base">Total Votes: {total_votes}, Public Votes: {pub_votes}, Prolific Votes: {pro_votes}</h6>
+                     """
+                 )
+             )
+             return WR_PLOT, BT_PLOT, UPDATE_TIME
+
+         except Exception as e:
+             raise gr.Error(f"Error processing file: {str(e)}")
+
+     return process_and_visualize
+
+
+ theme = gr.themes.Soft(
+     primary_hue=gr.themes.Color(
+         c100="#82000019",
+         c200="#82000033",
+         c300="#8200004c",
+         c400="#82000066",
+         c50="#8200007f",
+         c500="#8200007f",
+         c600="#82000099",
+         c700="#820000b2",
+         c800="#820000cc",
+         c900="#820000e5",
+         c950="#820000f2",
+     ),
+     secondary_hue="rose",
+     neutral_hue="stone",
+ )
+
+ # Create Gradio interface
+ with gr.Blocks(title="Talk Arena Leaderboard Analysis", theme=theme) as demo:
+     viz_factory(force=True)()
+     last_updated = UPDATE_TIME
+     with gr.Row():
+         bt_plot = gr.Plot(label="Bradley-Terry Ratings", value=BT_PLOT)
+     with gr.Row():
+         win_rate_plot = gr.Plot(label="Win Rates", value=WR_PLOT)
+
+     d1 = gr.Textbox(visible=False)
+     demo.load(
+         fn=viz_factory(force=False), inputs=[], outputs=[win_rate_plot, bt_plot, last_updated], show_progress="minimal"
+     )
+     demo.load(fn=get_wr_plot, inputs=[], outputs=[d1])
+     demo.load(fn=get_bt_plot, inputs=[], outputs=[d1])
+     demo.load(fn=get_update_time, inputs=[], outputs=[d1])
+
+ if __name__ == "__main__":
+     scheduler = BackgroundScheduler()
+     scheduler.add_job(func=viz_factory(force=True), trigger="interval", seconds=300)
+     scheduler.start()
+     demo.queue(default_concurrency_limit=10, api_open=True).launch(share=True, server_port=8004, node_port=8005)
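A toy run of the Bradley-Terry pipeline above, on hand-written votes in the same JSON shape that load_live_votes expects; the model names and vote counts are made up for illustration:

# hypothetical toy input — not part of the commit
import json

toy_votes = {
    "_default": {
        "1": {"model_a": "m1", "model_b": "m2", "outcome": 0},    # m1 wins
        "2": {"model_a": "m1", "model_b": "m2", "outcome": 0},    # m1 wins
        "3": {"model_a": "m2", "model_b": "m1", "outcome": 0.5},  # tie
    }
}
ratings = compute_bootstrap_bt(json.dumps(toy_votes), num_round=100)
print(ratings.median())  # m1 should land above the 1000 baseline, m2 below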
talk_arena/streaming_helpers.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import json
4
+ import os
5
+ from collections import defaultdict
6
+ from pathlib import Path
7
+
8
+ import google.generativeai as genai
9
+ import gradio as gr
10
+ import librosa
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import torch
14
+ import xxhash
15
+ from datasets import Audio
16
+ from openai import AsyncOpenAI
17
+ from transformers import AutoModel, AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer
18
+ from transformers.generation import GenerationConfig
19
+
20
+
21
+ def _get_prompt_for_model_name(model_id):
22
+ prompt_dict = defaultdict(lambda: "You are a helpful assistant. Respond conversationally to the speech provided.")
23
+ # Requested Overrides
24
+ prompt_dict["scb10x/llama-3-typhoon-audio-8b-2411"] = (
25
+ "You are a helpful assistant. Respond conversationally to the speech provided in the language it is spoken in."
26
+ )
27
+ return prompt_dict[model_id]
28
+
29
+
30
+ def _get_config_for_model_name(model_id):
31
+ if "API_MODEL_CONFIG" in os.environ:
32
+ return json.loads(os.environ["API_MODEL_CONFIG"])[model_id]
33
+ return {
34
+ "pipeline/meta-llama/Meta-Llama-3-8B-Instruct": {"base_url": "http://localhost:8001/v1", "api_key": "empty"},
35
+ "scb10x/llama-3-typhoon-audio-8b-2411": {
36
+ "base_url": "http://localhost:8002/v1",
37
+ "api_key": "empty",
38
+ },
39
+ "WillHeld/DiVA-llama-3-v0-8b": {
40
+ "base_url": "http://localhost:8003/v1",
41
+ "api_key": "empty",
42
+ },
43
+ "Qwen/Qwen2-Audio-7B-Instruct": {
44
+ "base_url": "http://localhost:8004/v1",
45
+ "api_key": "empty",
46
+ },
47
+ }[model_id]
48
+
49
+
50
+ def gradio_gen_factory(streaming_fn, model_name, anonymous):
51
+ async def gen_from(audio_input, order):
52
+ with torch.no_grad():
53
+ prev_resp = ""
54
+ async for resp in streaming_fn(audio_input):
55
+ for char in range(len(prev_resp), len(resp)):
56
+ my_resp = gr.Textbox(
57
+ value=resp[: char + 1],
58
+ info="",
59
+ visible=True,
60
+ label=model_name if not anonymous else f"Model {order+1}",
61
+ elem_classes="lam-response-box",
62
+ )
63
+ yield my_resp
64
+ await asyncio.sleep(0.001)
65
+ prev_resp = resp
66
+
67
+ return gen_from
68
+
69
+
70
+ def gemini_streaming(model_id):
71
+ genai.configure(api_key=os.environ["GEMINI_API_KEY"])
72
+ resampler = Audio(sampling_rate=16_000)
73
+ model = genai.GenerativeModel(model_id)
74
+
75
+ async def get_chat_response(audio_input):
76
+ if audio_input is None:
77
+ raise StopAsyncIteration("")
78
+ sr, y = audio_input
79
+ x = xxhash.xxh32(bytes(y)).hexdigest()
80
+ y = y.astype(np.float32)
81
+ y /= np.max(np.abs(y))
82
+ a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
83
+ sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
84
+ prompt = "You are a helpful assistant. Respond conversationally to the speech provided."
85
+ inputs = [prompt, {"mime_type": "audio/wav", "data": Path(f"{x}.wav").read_bytes()}]
86
+ text_response = []
87
+ responses = model.generate_content(inputs, stream=True)
88
+ for chunk in responses:
89
+ text_response.append(chunk.text)
90
+ yield "".join(text_response)
91
+ os.remove(f"{x}.wav")
92
+
93
+ return get_chat_response, model
94
+
95
+
96
+ def gpt4o_streaming(model_id):
97
+ client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
98
+ resampler = Audio(sampling_rate=16_000)
99
+
100
+ async def get_chat_response(audio_input):
101
+ if audio_input is None:
102
+ raise StopAsyncIteration("")
103
+ sr, y = audio_input
104
+ x = xxhash.xxh32(bytes(y)).hexdigest()
105
+ y = y.astype(np.float32)
106
+ y /= np.max(np.abs(y))
107
+ a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
108
+ sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
109
+ with open(f"{x}.wav", "rb") as wav_file:
110
+ wav_data = wav_file.read()
111
+ encoded_string = base64.b64encode(wav_data).decode("utf-8")
112
+ prompt = "You are a helpful assistant. Respond conversationally to the speech provided."
113
+ try:
114
+ completion = await client.chat.completions.create(
115
+ model="gpt-4o-audio-preview",
116
+ modalities=["text", "audio"],
117
+ audio={"voice": "alloy", "format": "wav"},
118
+ messages=[
119
+ {
120
+ "role": "user",
121
+ "content": [
122
+ {"type": "text", "text": prompt},
123
+ {"type": "input_audio", "input_audio": {"data": encoded_string, "format": "wav"}},
124
+ ],
125
+ },
126
+ ],
127
+ )
128
+ os.remove(f"{x}.wav")
129
+ yield completion.choices[0].message.audio.transcript
130
+ except:
131
+ raise StopAsyncIteration("error")
132
+
133
+ return get_chat_response, client
134
+
135
+
136
+ async def llm_streaming(model_id: str, prompt: str):
137
+ if "gpt" in model_id:
138
+ client = AsyncOpenAI()
139
+ else:
140
+ client = AsyncOpenAI(**_get_config_for_model_name(model_id))
141
+ try:
142
+ completion = await client.chat.completions.create(
143
+ model=model_id,
144
+ messages=[
145
+ {"role": "system", "content": "You are helpful assistant."},
146
+ {
147
+ "role": "user",
148
+ "content": prompt,
149
+ },
150
+ ],
151
+ stream=True,
152
+ )
153
+ text_response = []
154
+ async for chunk in completion:
155
+ if len(chunk.choices) > 0:
156
+ text_response.append(chunk.choices[0].delta.content)
157
+ yield "".join(text_response)
158
+ except:
159
+ raise StopAsyncIteration("error")
160
+
161
+
162
+ def asr_streaming(model_id, asr_pipe):
163
+ resampler = Audio(sampling_rate=16_000)
164
+
165
+ async def pipelined(audio_input):
166
+ if audio_input is None:
167
+ raise StopAsyncIteration("")
168
+ sr, y = audio_input
169
+ x = xxhash.xxh32(bytes(y)).hexdigest()
170
+ y = y.astype(np.float32)
171
+ y /= max(np.max(np.abs(y)), 1e-9)  # guard against division by zero on silent (all-zero) input
172
+ a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
173
+ sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
174
+ result = await asyncio.to_thread(
175
+ asr_pipe, f"{x}.wav", generate_kwargs={"task": "transcribe"}, return_timestamps=False
176
+ )
+ text = result["text"]  # pass the callable and its arguments to to_thread, not the already-evaluated result
177
+ os.remove(f"{x}.wav")
178
+ async for response in llm_streaming(model_id, prompt=text):
179
+ yield response
180
+
181
+ return pipelined
182
+
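asr_streaming expects an already-built transformers ASR pipeline. One plausible construction, with the Whisper checkpoint and downstream model id as assumptions:

from transformers import pipeline

asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-large-v3")  # checkpoint is illustrative
respond = asr_streaming("gpt-4o-mini", asr_pipe)  # model id is illustrative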
183
+
184
+ def api_streaming(model_id):
185
+ client = AsyncOpenAI(**_get_config_for_model_name(model_id))
186
+ resampler = Audio(sampling_rate=16_000)
187
+
188
+ async def get_chat_response(audio_input):
189
+ if audio_input is None:
190
+ return  # end the async generator; raising StopAsyncIteration inside one becomes a RuntimeError
191
+ sr, y = audio_input
192
+ x = xxhash.xxh32(bytes(y)).hexdigest()
193
+ y = y.astype(np.float32)
194
+ y /= max(np.max(np.abs(y)), 1e-9)  # guard against division by zero on silent (all-zero) input
195
+ a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
196
+ sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
197
+ with open(f"{x}.wav", "rb") as wav_file:
198
+ wav_data = wav_file.read()
199
+ encoded_string = base64.b64encode(wav_data).decode("utf-8")
200
+ try:
201
+ prompt = _get_prompt_for_model_name(model_id)
202
+ completion = await client.chat.completions.create(
203
+ model=model_id,
204
+ messages=[
205
+ {
206
+ "role": "user",
207
+ "content": [
208
+ {"type": "text", "text": prompt},
209
+ {"type": "audio", "audio_url": "data:audio/wav;base64," + encoded_string},
210
+ ],
211
+ },
212
+ ],
213
+ stream=True,
214
+ )
215
+ text_response = []
216
+ async for chunk in completion:
217
+ if len(chunk.choices) > 0:
218
+ text_response.append(chunk.choices[0].delta.content or "")  # delta.content is None on the final chunk
219
+ yield "".join(text_response)
220
+ os.remove(f"{x}.wav")
221
+ except Exception as e:
222
+ print(f"error for {model_id}: {e}")
223
+ if os.path.exists(f"{x}.wav"): os.remove(f"{x}.wav")  # clean up the temp file on the error path too
+ return  # end the async generator; raising StopAsyncIteration inside one becomes a RuntimeError
224
+
225
+ return get_chat_response, client
226
+
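_get_config_for_model_name and _get_prompt_for_model_name are called above but defined elsewhere. A hypothetical shape consistent with how the config helper is splatted into AsyncOpenAI; every name, URL, and environment variable below is illustrative only:

import os

def _get_config_for_model_name(model_id: str) -> dict:
    # Hypothetical: route non-OpenAI models to an OpenAI-compatible endpoint.
    return {
        "base_url": os.environ.get("API_BASE_URL", "http://localhost:8001/v1"),
        "api_key": os.environ.get("API_KEY", "EMPTY"),
    }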
227
+
228
+ # Local Hosting Utilities
229
+
230
+
231
+ def diva_streaming(diva_model_str):
232
+ diva_model = AutoModel.from_pretrained(diva_model_str, trust_remote_code=True, device_map="balanced_low_0")
233
+ resampler = Audio(sampling_rate=16_000)
234
+
235
+ async def diva_audio(audio_input, do_sample=False, temperature=0.001):
236
+ sr, y = audio_input
237
+ y = y.astype(np.float32)
238
+ y /= max(np.max(np.abs(y)), 1e-9)  # guard against division by zero on silent (all-zero) input
239
+ a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
240
+ stream = diva_model.generate_stream(
241
+ a["array"],
242
+ (
243
+ "You are a helpful assistant The user is talking to you with their voice and you are responding with"
244
+ " text."
245
+ ),
246
+ do_sample=do_sample,
247
+ max_new_tokens=256,
248
+ )
249
+ for text in stream:
250
+ yield text
251
+
252
+ return diva_audio, diva_model
253
+
254
+
255
+ def qwen2_streaming(qwen2_model_str):
256
+ resampler = Audio(sampling_rate=16_000)
257
+ qwen2_processor = AutoProcessor.from_pretrained(qwen2_model_str)
258
+ qwen2_model = Qwen2AudioForConditionalGeneration.from_pretrained(qwen2_model_str, device_map="auto")
259
+ qwen2_model.generation_config = GenerationConfig.from_pretrained(
260
+ qwen2_model_str,
261
+ trust_remote_code=True,
262
+ do_sample=False,
263
+ top_k=50,
264
+ top_p=1.0,
265
+ )
266
+
267
+ async def qwen2_audio(audio_input, do_sample=False, temperature=0.001):
268
+ if audio_input is None:
269
+ return  # end the async generator; raising StopAsyncIteration inside one becomes a RuntimeError
270
+ sr, y = audio_input
271
+ x = xxhash.xxh32(bytes(y)).hexdigest()
272
+ y = y.astype(np.float32)
273
+ y /= max(np.max(np.abs(y)), 1e-9)  # guard against division by zero on silent (all-zero) input
274
+ a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
275
+ sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
276
+ conversation = [
277
+ {"role": "system", "content": "You are a helpful assistant."},
278
+ {
279
+ "role": "user",
280
+ "content": [
281
+ {
282
+ "type": "audio",
283
+ "audio_url": f"{x}.wav",
284
+ },
285
+ ],
286
+ },
287
+ ]
288
+ text = qwen2_processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
289
+ audios = [librosa.load(f"{x}.wav", sr=qwen2_processor.feature_extractor.sampling_rate)[0]]
290
+ inputs = qwen2_processor(text=text, audios=audios, return_tensors="pt", padding=True)
291
+ streamer = TextIteratorStreamer(qwen2_processor)
292
+ generation_task = asyncio.create_task(asyncio.to_thread(qwen2_model.generate, **inputs, streamer=streamer, max_length=256))  # generate() is blocking and returns tensors, not a coroutine, so run it in a worker thread
293
+
294
+ generated_text = ""
295
+ # TextIteratorStreamer is a blocking sync iterator, not an async one; pull tokens via the executor so the event loop stays responsive
+ stream_iter = iter(streamer)
+ while (new_text := await asyncio.get_running_loop().run_in_executor(None, next, stream_iter, None)) is not None:
296
+ generated_text += new_text
297
+ yield generated_text.split("<|im_start|>assistant\n")[-1].replace("<|im_end|>", "")
298
+
299
+ await generation_task
300
+ os.remove(f"{x}.wav")
301
+
302
+ return qwen2_audio, qwen2_model
303
+
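For reference, the synchronous TextIteratorStreamer recipe from the transformers documentation that the executor-based loop above adapts; tokenizer, model, and inputs are placeholders:

from threading import Thread
from transformers import TextIteratorStreamer

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=256))
thread.start()
for text in streamer:  # blocks until the generation thread produces the next chunk
    print(text, end="")
thread.join()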
304
+
305
+ def typhoon_streaming(typhoon_model_str, device="cuda:0"):
306
+ resampler = Audio(sampling_rate=16_000)
307
+ typhoon_model = AutoModel.from_pretrained(typhoon_model_str, torch_dtype=torch.float16, trust_remote_code=True)
308
+ tokenizer = typhoon_model.llama_tokenizer
309
+ typhoon_model.to(device)
310
+ typhoon_model.eval()
311
+ prompt_pattern = (
312
+ "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<Speech><SpeechHere></Speech>"
313
+ " {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
314
+ )
315
+ prompt = (
316
+ "You are a helpful assistant. Respond conversationally to the speech provided in the language it is spoken in."
317
+ )
318
+
319
+ async def typhoon_audio(audio_input, do_sample=False, temperature=0.001):
320
+ if audio_input is None:
321
+ return  # end the async generator; raising StopAsyncIteration inside one becomes a RuntimeError
322
+ sr, y = audio_input
323
+ x = xxhash.xxh32(bytes(y)).hexdigest()
324
+ y = y.astype(np.float32)
325
+ y /= max(np.max(np.abs(y)), 1e-9)  # guard against division by zero on silent (all-zero) input
326
+ a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
327
+ streamer = TextIteratorStreamer(tokenizer)
328
+ generation_task = asyncio.create_task(asyncio.to_thread(  # generate() blocks; run it in a worker thread
329
+ typhoon_model.generate,
330
+ audio=a["array"],
331
+ prompt=prompt,
332
+ prompt_pattern=prompt_pattern,
333
+ device=device,
334
+ do_sample=False,
335
+ max_length=1200,
336
+ num_beams=1,
337
+ streamer=streamer, # supports TextIteratorStreamer
338
+ )
339
+ )
340
+ generated_text = ""
341
+ # TextIteratorStreamer is a blocking sync iterator, not an async one; pull tokens via the executor so the event loop stays responsive
+ stream_iter = iter(streamer)
+ while (new_text := await asyncio.get_running_loop().run_in_executor(None, next, stream_iter, None)) is not None:
342
+ generated_text += new_text
343
+ yield generated_text.split("<|start_header_id|>assistant<|end_header_id|>\n\n")[-1].replace(
344
+ "<|eot_id|>", ""
345
+ )
346
+ await generation_task
347
+
348
+ return typhoon_audio, typhoon_model
talk_arena/styles.css ADDED
@@ -0,0 +1,25 @@
1
+ @media (max-width: 768px) {
2
+ .lam-response-box {
3
+ max-height: 230px;
4
+ }
5
+
6
+ .lam-response-box > label {
7
+ max-height: 100%;
8
+ }
9
+
10
+ .lam-response-box > label > div > textarea {
11
+ max-height: 100%;
12
+ height: 100% !important;
13
+ }
14
+ }
15
+
16
+ @media (min-width: 769px) {
17
+ .lam-response-box {
18
+ max-height: 40vh;
19
+ }
20
+
21
+ .lam-response-box > label > div > textarea {
22
+ max-height: calc(40vh - 50px) !important;
23
+ }
24
+ }
25
+
talk_arena/viz/core.py ADDED
@@ -0,0 +1,324 @@
1
+ import json
2
+ import random
3
+ import textwrap
4
+ from collections import defaultdict
5
+ from datetime import datetime
6
+ from typing import Dict, List, Tuple
7
+ from zoneinfo import ZoneInfo
8
+
9
+ import numpy as np
10
+ import pandas as pd
11
+ import plotly.express as px
12
+ import plotly.io as pio
13
+ from scipy.optimize import minimize
14
+ from scipy.special import expit
15
+
16
+ # Constants
17
+ COLORS = [
18
+ "#1B7FFF",
19
+ "#F07D1A",
20
+ "#BA24C7",
21
+ "#FE42C7",
22
+ "#0D4B7C",
23
+ "#0EAC96",
24
+ "#AA7CFF",
25
+ "#B50550",
26
+ "#009EEB",
27
+ "#220B55",
28
+ "#7B3301",
29
+ ]
30
+ NAME_MAPPING = {
31
+ "gemini_2f": "Gemini 2.0 (Exp)",
32
+ "diva_3_8b": "DiVA Llama 3 8B",
33
+ "qwen2": "Qwen 2 Audio",
34
+ "pipe_l3.0": "Pipelined Llama 3 8B",
35
+ "gemini_1.5f": "Gemini 1.5 Flash",
36
+ "gpt4o": "GPT-4o",
37
+ "gemini_1.5p": "Gemini 1.5 Pro",
38
+ "typhoon_audio": "Typhoon Audio",
39
+ }
40
+
41
+ def get_aesthetic_timestamp():
42
+ """
43
+ Returns a beautifully formatted timestamp in the format:
44
+ 'Tuesday, December 10th, 2024 at 3:45 PM'
45
+ """
46
+ # Get timezone object for PST
47
+ pst = ZoneInfo("America/Los_Angeles")
48
+
49
+ # Get current time in PST
50
+ now = datetime.now(pst)
51
+
52
+ # Add suffix to day number (1st, 2nd, 3rd, etc.)
53
+ day = now.day
54
+ if 4 <= day <= 20 or 24 <= day <= 30:
55
+ suffix = "th"
56
+ else:
57
+ suffix = ["st", "nd", "rd"][day % 10 - 1]
58
+ return now.strftime(f"%A, %B {day}{suffix}, %Y at %-I:%M %p")  # note: %-I is a glibc extension; not portable to Windows
59
+
60
+
61
+ def bootstrap_ci(data, n_bootstrap=10000, ci=95):
62
+ """Calculate bootstrap confidence intervals."""
63
+ bootstrap_samples = []
64
+ for _ in range(n_bootstrap):
65
+ bootstrap_samples.append(np.mean(random.choices(data, k=len(data))))
66
+ lower_bound = np.percentile(bootstrap_samples, (100 - ci) / 2)
67
+ upper_bound = np.percentile(bootstrap_samples, 100 - (100 - ci) / 2)
68
+ return lower_bound, upper_bound
69
+
70
+
71
+ def calculate_win_rates(json_data):
72
+ """Calculate win rates from JSON data."""
73
+ data = json.loads(json_data)
74
+
75
+ model_wins = defaultdict(int)
76
+ total_matches = defaultdict(int)
77
+ total_votes = 0
78
+
79
+ for value in data["_default"].values():
80
+ total_votes += 1
81
+ if value["outcome"] == 0:
82
+ model_wins[value["model_a"]] += 1
83
+ elif value["outcome"] == 1:
84
+ model_wins[value["model_b"]] += 1
85
+ elif value["outcome"] == 0.5:
86
+ model_wins[value["model_a"]] += 0.5
87
+ model_wins[value["model_b"]] += 0.5
88
+ total_matches[value["model_a"]] += 1
89
+ total_matches[value["model_b"]] += 1
90
+
91
+ per_model_wins = {}
92
+ for model, wins in model_wins.items():
93
+ win_rate = wins / total_matches[model]
94
+ wins_data = [1] * int(wins) + [0] * int(total_matches[model] - wins)
95
+ if int(wins) != wins:
96
+ wins_data += [0.5]
97
+ lower, upper = bootstrap_ci(wins_data)
98
+ per_model_wins[model] = {
99
+ "model": model,
100
+ "win_rate": win_rate,
101
+ "95_lower": (win_rate - lower),
102
+ "95_upper": (upper - win_rate),
103
+ }
104
+ df = pd.DataFrame.from_dict(per_model_wins).T
105
+
106
+ return df, total_votes
107
+
108
+
109
+ def create_win_rate_plot(wins_df):
110
+ """Create win rate plot using Plotly."""
111
+ wins_df["Source"] = wins_df["Source"].astype(str)
112
+ wins_df = wins_df.sort_values(by=["Source", "win_rate"], ascending=False)
113
+ wins_df["model"] = wins_df["model"].apply(lambda x: NAME_MAPPING.get(x, x))
114
+
115
+ fig = px.bar(
116
+ wins_df,
117
+ x="model",
118
+ y="win_rate",
119
+ error_y="95_upper",
120
+ error_y_minus="95_lower",
121
+ color="model",
122
+ color_discrete_sequence=COLORS,
123
+ animation_group="model",
124
+ animation_frame="Source",
125
+ )
126
+
127
+ fig.update_traces(
128
+ hovertemplate="<b>%{x}</b><br>" + "Win Rate: %{y}" + "<extra></extra>",
129
+ )
130
+
131
+ fig.update_layout(
132
+ autosize=True,
133
+ showlegend=False,
134
+ plot_bgcolor="white",
135
+ title={
136
+ "text": "Talk Arena Live Win Rates<br>with 95% Confidence Intervals",
137
+ "y": 0.95,
138
+ "x": 0.5,
139
+ "xanchor": "center",
140
+ "yanchor": "top",
141
+ },
142
+ xaxis_title="Model",
143
+ yaxis_title="Win Rate (%)",
144
+ bargap=0.2,
145
+ yaxis=dict(
146
+ tickformat=".0%", tickmode="auto", range=[0, 1.01], gridcolor="#C9CCD1", griddash="dash", gridwidth=2
147
+ ),
148
+ legend=dict(
149
+ orientation="h", # Make legend horizontal
150
+ yanchor="bottom",
151
+ y=-0.5, # Position below plot
152
+ xanchor="center",
153
+ x=0.5, # Center horizontally
154
+ bgcolor="rgba(255, 255, 255, 0.8)",
155
+ bordercolor="#C9CCD1",
156
+ borderwidth=1,
157
+ ),
158
+ margin=dict(l=10, r=10, t=0, b=10), # Balanced margins
159
+ hoverlabel=dict(bgcolor="white", font_size=14, bordercolor="gray"),
160
+ )
161
+
162
+ fig.update_xaxes(showgrid=False)
163
+
164
+ return fig
165
+
166
+
167
+ # Bradley-Terry Model Functions
168
+ def load_live_votes(json_str: str) -> pd.DataFrame:
169
+ """Load and preprocess live votes data from JSON string."""
170
+ data = json.loads(json_str)
171
+ df = pd.DataFrame.from_dict(data["_default"], orient="index")
172
+ df["winner"] = df["outcome"].map({1: "model_b", 0: "model_a", 0.5: "tie"})
173
+ return df
174
+
175
+
176
+ def preprocess_for_bt(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, List[str], np.ndarray]:
177
+ """Preprocess data for Bradley-Terry model fitting."""
178
+ all_models = pd.concat([df["model_a"], df["model_b"]]).unique()
179
+ model_to_idx = {model: idx for idx, model in enumerate(all_models)}
180
+
181
+ matchups = np.array([[model_to_idx[row.model_a], model_to_idx[row.model_b]] for _, row in df.iterrows()])
182
+
183
+ outcomes = np.array(
184
+ [1.0 if row.winner == "model_a" else (0.5 if row.winner == "tie" else 0.0) for _, row in df.iterrows()]
185
+ )
186
+
187
+ unique_matches = np.column_stack([matchups, outcomes])
188
+ unique_matches, weights = np.unique(unique_matches, return_counts=True, axis=0)
189
+
190
+ return (unique_matches[:, :2].astype(np.int32), unique_matches[:, 2], list(all_models), weights.astype(np.float64))
191
+
192
+
193
+ def bt_loss_and_grad(
194
+ ratings: np.ndarray, matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, alpha: float = 1.0
195
+ ) -> Tuple[float, np.ndarray]:
196
+ """Compute Bradley-Terry loss and gradient."""
197
+ matchup_ratings = ratings[matchups]
198
+ logits = alpha * (matchup_ratings[:, 0] - matchup_ratings[:, 1])
199
+ probs = expit(logits)
200
+
201
+ loss = -((np.log(probs) * outcomes + np.log(1.0 - probs) * (1.0 - outcomes)) * weights).sum()
202
+
203
+ matchups_grads = -alpha * (outcomes - probs) * weights
204
+ model_grad = np.zeros_like(ratings)
205
+ np.add.at(model_grad, matchups[:, [0, 1]], matchups_grads[:, None] * np.array([1.0, -1.0], dtype=np.float64))
206
+
207
+ return loss, model_grad
208
+
209
+
210
+ def fit_bt(
211
+ matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, n_models: int, alpha: float, tol: float = 1e-6
212
+ ) -> np.ndarray:
213
+ """Fit Bradley-Terry model using L-BFGS-B optimization."""
214
+ initial_ratings = np.zeros(n_models, dtype=np.float64)
215
+
216
+ result = minimize(
217
+ fun=bt_loss_and_grad,
218
+ x0=initial_ratings,
219
+ args=(matchups, outcomes, weights, alpha),
220
+ jac=True,
221
+ method="L-BFGS-B",
222
+ options={"disp": False, "maxiter": 100, "gtol": tol},
223
+ )
224
+
225
+ return result["x"]
226
+
227
+
228
+ def scale_and_offset(
229
+ ratings: np.ndarray, models: List[str], scale: float = 400, init_rating: float = 1000
230
+ ) -> np.ndarray:
231
+ """Scale ratings to familiar Elo-like scale."""
232
+ scaled_ratings = (ratings * scale) + init_rating
233
+ return scaled_ratings
234
+
235
+
236
+ def compute_bootstrap_bt(
237
+ data: str,
238
+ num_round: int = 100,
239
+ base: float = 10.0,
240
+ scale: float = 400.0,
241
+ init_rating: float = 1000.0,
242
+ tol: float = 1e-6,
243
+ ) -> pd.DataFrame:
244
+ """Compute bootstrap Bradley-Terry ratings from live votes data."""
245
+ df = load_live_votes(data)
246
+ matchups, outcomes, models, weights = preprocess_for_bt(df)
247
+
248
+ rng = np.random.default_rng(seed=0)
249
+ total_matches = len(df)
250
+ idxs = rng.multinomial(n=total_matches, pvals=weights / weights.sum(), size=num_round)
251
+ boot_weights = idxs.astype(np.float64) / total_matches
252
+
253
+ ratings_list = []
254
+ for sample_weights in boot_weights:
255
+ ratings = fit_bt(
256
+ matchups=matchups,
257
+ outcomes=outcomes,
258
+ weights=sample_weights,
259
+ n_models=len(models),
260
+ alpha=np.log(base),
261
+ tol=tol,
262
+ )
263
+ scaled_ratings = scale_and_offset(ratings=ratings, models=models, scale=scale, init_rating=init_rating)
264
+ ratings_list.append(scaled_ratings)
265
+
266
+ df_ratings = pd.DataFrame(ratings_list, columns=models)
267
+ return df_ratings[df_ratings.median().sort_values(ascending=False).index]
268
+
269
+
270
+ def create_bt_plot(bootstrap_ratings):
271
+ """Create Bradley-Terry ratings plot using Plotly."""
272
+ melted_bootstrap = bootstrap_ratings.melt(id_vars=["Source", "level_1"], var_name="Model", value_name="BT")
273
+ melted_bootstrap = melted_bootstrap.dropna()
274
+ melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "Model", "BT"], ascending=False)
275
+ # Pretty Names
276
+ melted_bootstrap["Model"] = melted_bootstrap["Model"].apply(lambda x: NAME_MAPPING.get(x, x))
277
+ # Compression for Client Side
278
+ melted_bootstrap["BT"] = melted_bootstrap["BT"].apply(lambda x: int(x))
279
+ min_samp = melted_bootstrap[melted_bootstrap["BT"] > 0]["BT"].min()
280
+ max_samp = melted_bootstrap["BT"].max()
281
+ idx_keep = list(range(0, len(melted_bootstrap), 10))
282
+ melted_bootstrap = melted_bootstrap.iloc[idx_keep]
283
+ melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "BT"], ascending=False)
284
+ fig = px.violin(
285
+ melted_bootstrap,
286
+ x="Model",
287
+ y="BT",
288
+ color="Model",
289
+ animation_group="Model",
290
+ animation_frame="Source",
291
+ color_discrete_sequence=COLORS,
292
+ )
293
+
294
+ fig.update_layout(
295
+ autosize=True,
296
+ showlegend=False,
297
+ plot_bgcolor="white",
298
+ title={
299
+ "text": "Talk Arena Live Bradley-Terry Ratings<br>with Bootstrapped Variance",
300
+ "y": 0.9,
301
+ "x": 0.5,
302
+ "xanchor": "center",
303
+ "yanchor": "top",
304
+ },
305
+ xaxis_title="Model",
306
+ yaxis_title="Rating",
307
+ yaxis=dict(gridcolor="#C9CCD1", range=[min_samp - 10, max_samp + 10], griddash="dash"),
308
+ legend=dict(
309
+ orientation="h", # Make legend horizontal
310
+ yanchor="bottom",
311
+ y=-0.5, # Position below plot
312
+ xanchor="center",
313
+ x=0.5, # Center horizontally
314
+ bgcolor="rgba(255, 255, 255, 0.8)",
315
+ bordercolor="#C9CCD1",
316
+ borderwidth=1,
317
+ ),
318
+ margin=dict(l=10, r=10, t=0, b=10), # Balanced margins
319
+ )
320
+
321
+ fig.update_xaxes(showgrid=False)
322
+ fig.update_yaxes(showgrid=True, gridwidth=2)
323
+
324
+ return fig
talk_arena/viz/server.py ADDED
@@ -0,0 +1,323 @@
1
+ import hashlib
2
+ import json
3
+ import textwrap
4
+ import time
5
+ from datetime import datetime
6
+ from typing import Optional
7
+ from zoneinfo import ZoneInfo
8
+
9
+ import plotly.io as pio
10
+ from apscheduler.schedulers.background import BackgroundScheduler
11
+ from fastapi import FastAPI, HTTPException, Response
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+
14
+ from talk_arena.viz.core import *
15
+
16
+ app = FastAPI(title="Talk Arena API", description="API for Talk Arena leaderboard and statistics", version="0.0.1")
17
+
18
+ # Add CORS middleware
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"], # In production, replace with specific origins
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+
28
+ # Global variables to store the plots and update time
29
+ class GlobalState:
30
+ WR_PLOT = None
31
+ BT_PLOT = None
32
+ UPDATE_TIME = None
33
+ LAST_PROCESSED = None
34
+ MIN_UPDATE_INTERVAL = 60 # Minimum seconds between updates
35
+
36
+
37
+ state = GlobalState()
38
+
39
+
40
+ def process_and_visualize(force: bool = False):
41
+ """Process data and create visualizations"""
42
+ global state
43
+ current_time = datetime.now(ZoneInfo("America/Los_Angeles"))
44
+
45
+ # Check if enough time has passed since last update
46
+ if not force and state.LAST_PROCESSED:
47
+ time_diff = (current_time - state.LAST_PROCESSED).total_seconds()
48
+ if time_diff < state.MIN_UPDATE_INTERVAL:
49
+ logger.info(f"Skipping update - only {time_diff:.1f} seconds since last update")
50
+ return
51
+
52
+ state.LAST_PROCESSED = current_time
53
+ if state.WR_PLOT is not None and state.BT_PLOT is not None and not force:
54
+ return
55
+
56
+ try:
57
+ # Read JSON data
58
+ pub_json_data = open("/home/wheld3/talk-arena/live_votes.json", "r").read()
59
+ prolific_json_data = open("/home/wheld3/talk-arena/prolific_votes.json", "r").read()
60
+ merged_json_data = json.dumps(
61
+ {"_default": {**json.loads(pub_json_data)["_default"], **json.loads(prolific_json_data)["_default"]}}
62
+ )
63
+
64
+ # Calculate win rates and create plots
65
+ pub_win_rates, pub_votes = calculate_win_rates(pub_json_data)
66
+ pro_win_rates, pro_votes = calculate_win_rates(prolific_json_data)
67
+ total_win_rates, total_votes = calculate_win_rates(merged_json_data)
68
+
69
+ # Process win rates
70
+ all_models = total_win_rates["model"].unique()
71
+ pro_models = pro_win_rates["model"].unique()
72
+ for model in all_models:
73
+ if model not in pro_models:
74
+ new_index = len(pro_win_rates)
75
+ pro_win_rates.loc[new_index] = [model, -0.1, -0.1, -0.2]
76
+
77
+ win_rates = (
78
+ pd.concat([pub_win_rates, pro_win_rates, total_win_rates], keys=["Public", "Prolific", "Total"])
79
+ .reset_index()
80
+ .rename(columns={"level_0": "Source"})
81
+ )
82
+
83
+ state.WR_PLOT = create_win_rate_plot(win_rates)
84
+
85
+ # Calculate Bradley-Terry ratings
86
+ pub_bootstrap_ratings = compute_bootstrap_bt(pub_json_data, num_round=10000)
87
+ pro_bootstrap_ratings = compute_bootstrap_bt(prolific_json_data, num_round=10000)
88
+ total_bootstrap_ratings = compute_bootstrap_bt(merged_json_data, num_round=10000)
89
+
90
+ for model in all_models:
91
+ if model not in pro_models:
92
+ pro_bootstrap_ratings[model] = pro_bootstrap_ratings["diva_3_8b"] * -1
93
+
94
+ bootstrap_ratings = (
95
+ pd.concat(
96
+ [pub_bootstrap_ratings, pro_bootstrap_ratings, total_bootstrap_ratings],
97
+ keys=["Public", "Prolific", "Total"],
98
+ )
99
+ .reset_index()
100
+ .rename(columns={"level_0": "Source"})
101
+ )
102
+
103
+ state.BT_PLOT = create_bt_plot(bootstrap_ratings)
104
+
105
+ # Update timestamp and vote counts
106
+ state.UPDATE_TIME = {
107
+ "timestamp": get_aesthetic_timestamp(),
108
+ "total_votes": total_votes,
109
+ "public_votes": pub_votes,
110
+ "prolific_votes": pro_votes,
111
+ }
112
+
113
+ except Exception as e:
114
+ raise HTTPException(status_code=500, detail=f"Error processing data: {str(e)}")
115
+
116
+
117
+ # Set up logging
118
+ import logging
119
+
120
+
121
+ logging.basicConfig(level=logging.INFO)
122
+ logger = logging.getLogger(__name__)
123
+
124
+ # Global scheduler instance
125
+ scheduler = None
126
+
127
+
128
+ def update_job():
129
+ """Wrapper for the update job with error handling and logging"""
130
+ try:
131
+ logger.info("Starting scheduled update...")
132
+ process_and_visualize(force=True)
133
+ logger.info("Scheduled update completed successfully")
134
+ except Exception as e:
135
+ logger.error(f"Error in scheduled update: {str(e)}", exc_info=True)
136
+
137
+
138
+ @app.on_event("startup")
139
+ async def startup_event():
140
+ """Initialize data and start scheduler"""
141
+ global scheduler
142
+
143
+ try:
144
+ logger.info("Starting initial data processing...")
145
+ process_and_visualize(force=True)
146
+ logger.info("Initial data processing completed")
147
+
148
+ # Clear any existing schedulers
149
+ if scheduler:
150
+ scheduler.shutdown(wait=False)
151
+
152
+ # Initialize and start the scheduler
153
+ scheduler = BackgroundScheduler(
154
+ timezone=ZoneInfo("America/Los_Angeles"), job_defaults={"coalesce": True, "max_instances": 1}
155
+ )
156
+
157
+ # Add the job with proper error handling
158
+ scheduler.add_job(
159
+ func=update_job, # Use the wrapper function
160
+ trigger="interval",
161
+ seconds=300,
162
+ id="update_visualizations",
163
+ name="Update Visualizations",
164
+ misfire_grace_time=60,
165
+ )
166
+
167
+ scheduler.start()
168
+ logger.info("Scheduler started successfully")
169
+
170
+ # Verify the job was added
171
+ jobs = scheduler.get_jobs()
172
+ logger.info(f"Current scheduled jobs: {[job.name for job in jobs]}")
173
+
174
+ except Exception as e:
175
+ logger.error(f"Error during startup: {str(e)}", exc_info=True)
176
+ raise
177
+
178
+
179
+ @app.on_event("shutdown")
180
+ async def shutdown_event():
181
+ """Properly shutdown the scheduler when the app stops"""
182
+ global scheduler
183
+ try:
184
+ if scheduler:
185
+ logger.info("Shutting down scheduler...")
186
+ scheduler.shutdown(wait=False)
187
+ logger.info("Scheduler shutdown complete")
188
+ except Exception as e:
189
+ logger.error(f"Error during scheduler shutdown: {str(e)}", exc_info=True)
190
+
191
+
192
+ # Add an endpoint to manually trigger an update
193
+ @app.post("/api/trigger-update")
194
+ async def trigger_update():
195
+ """Manually trigger a data update"""
196
+ try:
197
+ logger.info("Manual update triggered")
198
+ process_and_visualize(force=True)
199
+ logger.info("Manual update completed")
200
+ return {"status": "success", "message": "Update completed"}
201
+ except Exception as e:
202
+ logger.error(f"Error in manual update: {str(e)}", exc_info=True)
203
+ raise HTTPException(status_code=500, detail=str(e))
204
+
205
+
206
+ def generate_etag(data: dict) -> str:
207
+ """Generate an ETag for the given data"""
208
+ # Convert data to a consistent string representation and hash it
209
+ data_str = json.dumps(data, sort_keys=True)
210
+ return hashlib.md5(data_str.encode()).hexdigest()
211
+
212
+
213
+ @app.get("/api/win-rate-plot")
214
+ async def get_wr_plot(response: Response):
215
+ """Get the win rate plot data"""
216
+ if state.WR_PLOT is None:
217
+ raise HTTPException(status_code=503, detail="Plot data not yet available")
218
+
219
+ plot_json = json.loads(pio.to_json(state.WR_PLOT))
220
+
221
+ # Customize animation settings
222
+ for step in plot_json["layout"]["sliders"][0]["steps"]:
223
+ step["args"][1]["frame"]["duration"] = 500
224
+ step["args"][1]["transition"]["duration"] = 500
225
+
226
+ plot_json["layout"]["updatemenus"] = []
227
+ plot_json["layout"]["sliders"][0]["len"] = 0.8
228
+ plot_json["layout"]["sliders"][0]["pad"] = {}
229
+
230
+ # Generate ETag
231
+ etag = generate_etag(plot_json)
232
+ response.headers["ETag"] = etag
233
+
234
+ # Set cache control headers - cache for 4 minutes since we update every 5
235
+ response.headers["Cache-Control"] = "public, max-age=240"
236
+
237
+ return plot_json
238
+
239
+
240
+ @app.get("/api/bt-plot")
241
+ async def get_bt_plot(response: Response):
242
+ """Get the Bradley-Terry plot data"""
243
+ if state.BT_PLOT is None:
244
+ raise HTTPException(status_code=503, detail="Plot data not yet available")
245
+
246
+ plot_json = json.loads(pio.to_json(state.BT_PLOT))
247
+
248
+ # Customize animation settings
249
+ for step in plot_json["layout"]["sliders"][0]["steps"]:
250
+ step["args"][1]["frame"]["duration"] = 500
251
+ step["args"][1]["transition"]["duration"] = 500
252
+
253
+ plot_json["layout"]["updatemenus"] = []
254
+ plot_json["layout"]["sliders"][0]["len"] = 0.8
255
+ plot_json["layout"]["sliders"][0]["pad"] = {}
256
+
257
+ # Generate ETag
258
+ etag = generate_etag(plot_json)
259
+ response.headers["ETag"] = etag
260
+
261
+ # Set cache control headers - cache for 4 minutes since we update every 5
262
+ response.headers["Cache-Control"] = "public, max-age=240"
263
+
264
+ return plot_json
265
+
266
+
267
+ @app.get("/api/update-time")
268
+ async def get_update_time(response: Response):
269
+ """Get the last update time and vote counts"""
270
+ if state.UPDATE_TIME is None:
271
+ raise HTTPException(status_code=503, detail="Update time not yet available")
272
+
273
+ # Generate ETag
274
+ etag = generate_etag(state.UPDATE_TIME)
275
+ response.headers["ETag"] = etag
276
+
277
+ # Set cache control headers - cache for 4 minutes
278
+ response.headers["Cache-Control"] = "public, max-age=240"
279
+
280
+ return state.UPDATE_TIME
281
+
282
+
283
+ @app.get("/api/health")
284
+ async def health_check(response: Response):
285
+ """Enhanced health check endpoint with scheduler status"""
286
+ global scheduler
287
+
288
+ scheduler_status = "not_running"
289
+ next_run = None
290
+ last_run = state.UPDATE_TIME["timestamp"] if state.UPDATE_TIME else None
291
+
292
+ if scheduler:
293
+ try:
294
+ jobs = scheduler.get_jobs()
295
+ if jobs:
296
+ scheduler_status = "running"
297
+ next_run = jobs[0].next_run_time.strftime("%Y-%m-%d %H:%M:%S %Z")
298
+ except Exception as e:
299
+ logger.error(f"Error checking scheduler status: {str(e)}")
300
+ scheduler_status = f"error: {str(e)}"
301
+
302
+ health_data = {
303
+ "status": "healthy",
304
+ "scheduler_status": scheduler_status,
305
+ "last_update": last_run,
306
+ "next_scheduled_update": next_run,
307
+ "current_time": datetime.now(ZoneInfo("America/Los_Angeles")).strftime("%Y-%m-%d %H:%M:%S %Z"),
308
+ }
309
+
310
+ # Generate ETag
311
+ etag = generate_etag(health_data)
312
+ response.headers["ETag"] = etag
313
+
314
+ # Set cache control headers - short cache time for health check
315
+ response.headers["Cache-Control"] = "public, max-age=30"
316
+
317
+ return health_data
318
+
319
+
320
+ if __name__ == "__main__":
321
+ import uvicorn
322
+
323
+ uvicorn.run(app, host="0.0.0.0", port=8000)
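When run as a script, uvicorn serves the API on port 8000. Assuming the package is importable from the working directory, the equivalent CLI invocation would be uvicorn talk_arena.viz.server:app --host 0.0.0.0 --port 8000.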