Spaces · Running
Ali Sartaz Khan committed
Commit · 3c8c320
1 Parent(s): c2b5b47
Add application file
Browse files
- .gradio/certificate.pem +31 -0
- app.py +8 -0
- audio_out_votes.json +0 -0
- requirements.txt +15 -0
- talk_arena/.env +2 -0
- talk_arena/__init__.py +0 -0
- talk_arena/__pycache__/__init__.cpython-312.pyc +0 -0
- talk_arena/__pycache__/db_utils.cpython-312.pyc +0 -0
- talk_arena/__pycache__/streaming_helpers.cpython-312.pyc +0 -0
- talk_arena/audio_collection.py +448 -0
- talk_arena/db_utils.py +37 -0
- talk_arena/demo.py +432 -0
- talk_arena/leaderboard_viz.py +463 -0
- talk_arena/streaming_helpers.py +348 -0
- talk_arena/styles.css +25 -0
- talk_arena/viz/core.py +324 -0
- talk_arena/viz/server.py +323 -0
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
app.py
ADDED
@@ -0,0 +1,8 @@
+# app.py
+from talk_arena.audio_collection import demo
+import sys
+sys.path.append("talk-arena")
+from talk_arena.audio_collection import demo
+
+demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)
+
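Note: the Space entrypoint simply imports the pre-built `demo` Blocks object from `talk_arena.audio_collection` and launches it behind a shared queue. As a minimal sketch (not part of this commit; assumes the package is importable from the working directory and the API keys from `talk_arena/.env` are set), the same object can be launched locally without a public share link:

# local_run.py -- hypothetical helper, not part of this commit
from talk_arena.audio_collection import demo

# Launch on localhost only; queue settings mirror app.py but without sharing.
demo.queue(default_concurrency_limit=4, api_open=False).launch(share=False, ssr_mode=False)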
audio_out_votes.json
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,15 @@
+transformers==4.45.2
+transformers-stream-generator==0.0.5
+accelerate>=0.26.0
+peft
+gradio==5.8.0
+tinydb==4.8.0
+xxhash==3.4.1
+google-ai-generativelanguage==0.6.10
+google-generativeai
+datasets==2.18.0
+librosa==0.10.1
+soundfile==0.12.1
+openai==1.52.0
+python-dotenv==1.0.1
+httpx==0.27.2
talk_arena/.env
ADDED
@@ -0,0 +1,2 @@
+OPENAI_API_KEY="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"
+GEMINI_API_KEY="AIzaSyAM6XTT9S9nzE09jj5o-UNDZ4f8INPyWBM"
talk_arena/__init__.py
ADDED
File without changes
talk_arena/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (175 Bytes).
talk_arena/__pycache__/db_utils.cpython-312.pyc
ADDED
Binary file (2.62 kB).
talk_arena/__pycache__/streaming_helpers.cpython-312.pyc
ADDED
Binary file (18.8 kB).
talk_arena/audio_collection.py
ADDED
@@ -0,0 +1,448 @@
+import argparse
+import asyncio
+import os
+import random
+import textwrap
+import time
+
+import gradio as gr
+import numpy as np
+import soundfile as sf
+import xxhash
+from datasets import Audio
+from dotenv import load_dotenv
+from openai import OpenAI
+
+import talk_arena.streaming_helpers as sh
+from talk_arena.db_utils import TinyThreadSafeDB
+
+load_dotenv()
+resampler = Audio(sampling_rate=16_000)
+
+# /nlp/scr/askhan1/audioLLM_as_a_judge/talk-arena../src/talk_arena/audio_collection.py
+def parse_args():
+    parser = argparse.ArgumentParser(description="Talk Arena Demo")
+    parser.add_argument("--free_only", action="store_true", help="Only use free models")
+    return parser.parse_args()
+
+
+args = parse_args()
+
+if gr.NO_RELOAD: # Prevents Re-init during hot reloading
+    # Transcription Disabled for Public Interface
+    # asr_pipe = pipeline(
+    #     task="automatic-speech-recognition",
+    #     model="openai/whisper-large-v3-turbo",
+    #     chunk_length_s=30,
+    #     device="cuda:1",
+    # )
+
+    anonymous = True
+
+    gpt4o_audio, gpt4o_model = sh.gpt4o_streaming("models/gpt4o")
+    gemini2_audio, gemini2_model = sh.gemini_streaming("models/gemini-2.0-flash-exp")
+    competitor_info = [
+        (sh.gradio_gen_factory(gpt4o_audio, "GPT4o", anonymous), "gpt4o", "GPT-4o"),
+        (sh.gradio_gen_factory(gemini2_audio, "Gemini 2 Flash", anonymous), "gemini_2f", "Gemini 2 Flash"),
+    ]
+
+    resp_generators = [generator for generator, _, _ in competitor_info]
+    model_shorthand = [shorthand for _, shorthand, _ in competitor_info]
+    model_name = [full_name for _, _, full_name in competitor_info]
+    all_models = list(range(len(model_shorthand)))
+
+
+async def pairwise_response_async(audio_input, state, model_order):
+    if audio_input == None:
+        raise StopAsyncIteration(
+            "",
+            "",
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            state,
+            audio_input,
+            None,
+            None,
+            None,
+        )
+    spinner_id = 0
+    spinners = ["◐ ", "◓ ", "◑", "◒"]
+    spinner = spinners[0]
+    gen_pair = [resp_generators[model_order[0]], resp_generators[model_order[1]]]
+    latencies = [{}, {}]  # Store timing info for each model
+    resps = [gr.Textbox(value="", info="", visible=False), gr.Textbox(value="", info="", visible=False)]
+    tts_resps = [gr.Audio(), gr.Audio()]
+    error_in_model = False
+    for order, generator in enumerate(gen_pair):
+        start_time = time.time()
+        first_token = True
+        total_length = 0
+        try:
+            async for local_resp in generator(audio_input, order):
+                total_length += 1
+                if first_token:
+                    latencies[order]["time_to_first_token"] = time.time() - start_time
+                    first_token = False
+                resps[order] = local_resp
+                spinner = spinners[spinner_id]
+                spinner_id = (spinner_id + 1) % 4
+                yield (
+                    gr.Button(
+                        value=spinner + " Generating Responses " + spinner,
+                        interactive=False,
+                        variant="primary",
+                    ),
+                    resps[0],
+                    resps[1],
+                    tts_resps[0],
+                    tts_resps[1],
+                    gr.Button(visible=False),
+                    gr.Button(visible=False),
+                    gr.Button(visible=False),
+                    state,
+                    audio_input,
+                    None,
+                    None,
+                    latencies,
+                )
+            latencies[order]["total_time"] = time.time() - start_time
+            latencies[order]["response_length"] = total_length
+        except:
+            error_in_model = True
+            resps[order] = gr.Textbox(
+                info=f"<strong>Error thrown by Model {order+1} API</strong>",
+                value="" if first_token else resps[order]._constructor_args[0]["value"],
+                visible=True,
+                label=f"Model {order+1}",
+            )
+            yield (
+                gr.Button(
+                    value=spinner + " Generating Responses " + spinner,
+                    interactive=False,
+                    variant="primary",
+                ),
+                resps[0],
+                resps[1],
+                tts_resps[0],
+                tts_resps[1],
+                gr.Button(visible=False),
+                gr.Button(visible=False),
+                gr.Button(visible=False),
+                state,
+                audio_input,
+                None,
+                None,
+                latencies,
+            )
+
+        sr, y = audio_input
+        x = xxhash.xxh32(bytes(y)).hexdigest()
+        y = y.astype(np.float32)
+        y /= np.max(np.abs(y))
+        a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
+        sf.write(f"{x}_resp{order}.wav", a["array"], a["sampling_rate"], format="wav")
+        tts_options = {
+            "model": "gpt-4o-mini-tts",
+            "voice": "alloy",
+            "input": resps[order].__dict__["_constructor_args"][0]["value"],
+            "response_format": "wav",
+        }
+        abytes = OpenAI(api_key=os.environ["OPENAI_API_KEY"]).audio.speech.create(**tts_options).content
+        tts_resps[order] = gr.Audio(
+            value=abytes,
+            visible=True,
+        )
+        latencies[order]["total_time"] = time.time() - start_time
+        latencies[order]["response_length"] = total_length
+    print(latencies)
+    yield (
+        gr.Button(value="Vote for which model is better!", interactive=False, variant="primary", visible=False),
+        resps[0],
+        resps[1],
+        tts_resps[0],
+        tts_resps[1],
+        gr.Button(visible=not error_in_model),
+        gr.Button(visible=not error_in_model),
+        gr.Button(visible=not error_in_model),
+        responses_complete(state),
+        audio_input,
+        gr.Textbox(visible=False),
+        gr.Audio(visible=False),
+        latencies,
+    )
+
+
+def on_page_load(state, model_order):
+    if state == 0:
+        # gr.Info(
+        #     "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant,"
+        #     " or ChatGPT for."
+        # )
+        state = 1
+        model_order = random.sample(all_models, 2) if anonymous else model_order
+    return state, model_order
+
+
+def recording_complete(state):
+    if state == 1:
+        # gr.Info(
+        #     "Once you submit your recording, you'll receive responses from different models. This might take a second."
+        # )
+        state = 2
+    return (
+        gr.Button(value="Starting Generation", interactive=False, variant="primary"),
+        state,
+    )
+
+
+def responses_complete(state):
+    if state == 2:
+        gr.Info(
+            "Give us your feedback! Mark which model gave you the best response so we can understand the quality of"
+            " these different voice assistant models."
+        )
+        state = 3
+    return state
+
+
+def clear_factory(button_id):
+    async def clear(audio_input, model_order, pref_counter, reasoning, latency):
+        textbox1 = gr.Textbox(visible=False)
+        textbox2 = gr.Textbox(visible=False)
+        if button_id != None:
+            sr, y = audio_input
+            x = xxhash.xxh32(bytes(y)).hexdigest()
+            await db.insert(
+                {
+                    "audio_hash": x,
+                    "outcome": button_id,
+                    "model_a": model_shorthand[model_order[0]],
+                    "model_b": model_shorthand[model_order[1]],
+                    "why": reasoning,
+                    "model_a_latency": latency[0],
+                    "model_b_latency": latency[1],
+                }
+            )
+            pref_counter += 1
+            model_a = model_name[model_order[0]]
+            model_b = model_name[model_order[1]]
+
+        counter_text = f"# {pref_counter}/10 Preferences Submitted"
+        if pref_counter >= 10:
+            code = "C1ARB3D6"
+            counter_text = f"# Completed! Completion Code: {code}"
+        if anonymous:
+            model_order = random.sample(all_models, 2)
+        return (
+            model_order,
+            gr.Button(
+                value="Record Audio to Submit Again!",
+                interactive=False,
+                visible=True,
+            ),
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            None,
+            textbox1,
+            textbox2,
+            gr.Audio(visible=False),
+            gr.Audio(visible=False),
+            pref_counter,
+            counter_text,
+            gr.Textbox(visible=False),
+            gr.Audio(visible=False),
+        )
+
+    return clear
+
+
+def transcribe(transc, voice_reason):
+    if transc is None:
+        transc = ""
+    transc += " " + asr_pipe(voice_reason, generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"]
+    return transc, gr.Audio(value=None)
+
+
+theme = gr.themes.Soft(
+    primary_hue=gr.themes.Color(
+        c100="#82000019",
+        c200="#82000033",
+        c300="#8200004c",
+        c400="#82000066",
+        c50="#8200007f",
+        c500="#8200007f",
+        c600="#82000099",
+        c700="#820000b2",
+        c800="#820000cc",
+        c900="#820000e5",
+        c950="#820000f2",
+    ),
+    secondary_hue="rose",
+    neutral_hue="stone",
+)
+
+with open("../src/talk_arena/styles.css", "r") as css_file:
+    custom_css = css_file.read()
+
+db = TinyThreadSafeDB("audio_out_votes.json")
+
+with gr.Blocks(theme=theme, fill_height=True, css=custom_css) as demo:
+    submitted_preferences = gr.State(0)
+    state = gr.State(0)
+    model_order = gr.State([])
+    latency = gr.State([])
+    with gr.Row():
+        counter_text = gr.Markdown(
+            "# 0/10 Preferences Submitted.\n Follow the pop-up tips to submit your first preference."
+        )
+        category_description_text = gr.Markdown("PLACEHOLDER FOR ALI TO FILL IN LATER")
+    with gr.Row():
+        audio_input = gr.Audio(sources=["microphone"], streaming=False, label="Audio Input")
+
+    with gr.Row(equal_height=True):
+        with gr.Column(scale=1):
+            out1 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+            audio_out1 = gr.Audio(visible=False)
+        with gr.Column(scale=1):
+            out2 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+            audio_out2 = gr.Audio(visible=False)
+
+    with gr.Row():
+        btn = gr.Button(value="Record Audio to Submit!", interactive=False)
+
+    with gr.Row(equal_height=True):
+        reason = gr.Textbox(label="[Optional] Explain Your Preferences", visible=False, scale=4)
+        reason_record = gr.Audio(
+            sources=["microphone"],
+            interactive=True,
+            streaming=False,
+            label="Speak to transcribe!",
+            visible=False,
+            type="filepath",
+            # waveform_options={"show_recording_waveform": False},
+            scale=1,
+        )
+
+    with gr.Row():
+        best1 = gr.Button(value="Model 1 is better", visible=False)
+        tie = gr.Button(value="Tie", visible=False)
+        best2 = gr.Button(value="Model 2 is better", visible=False)
+
+    with gr.Row():
+        contact = gr.Markdown("")
+
+    # reason_record.stop_recording(transcribe, inputs=[reason, reason_record], outputs=[reason, reason_record])
+    audio_input.stop_recording(
+        recording_complete,
+        [state],
+        [btn, state],
+    ).then(
+        fn=pairwise_response_async,
+        inputs=[audio_input, state, model_order],
+        outputs=[
+            btn,
+            out1,
+            out2,
+            audio_out1,
+            audio_out2,
+            best1,
+            best2,
+            tie,
+            state,
+            audio_input,
+            reason,
+            reason_record,
+            latency,
+        ],
+    )
+    audio_input.start_recording(
+        lambda: gr.Button(value="Uploading Audio to Cloud", interactive=False, variant="primary"),
+        None,
+        btn,
+    )
+    best1.click(
+        fn=clear_factory(0),
+        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+        outputs=[
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            audio_out1,
+            audio_out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    tie.click(
+        fn=clear_factory(0.5),
+        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+        outputs=[
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            audio_out1,
+            audio_out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    best2.click(
+        fn=clear_factory(1),
+        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+        outputs=[
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            audio_out1,
+            audio_out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    audio_input.clear(
+        clear_factory(None),
+        [audio_input, model_order, submitted_preferences, reason, latency],
+        [
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            audio_out1,
+            audio_out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    demo.load(fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order])
+
+if __name__ == "__main__":
+    demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)
talk_arena/db_utils.py
ADDED
@@ -0,0 +1,37 @@
+import uuid
+from asyncio import Lock as ALock
+from contextlib import asynccontextmanager
+from threading import Lock as TLock
+
+from tinydb import TinyDB
+from tinydb.table import Table as TinyDBTable
+
+
+class UUIDTable(TinyDBTable):
+    document_id_class = uuid.UUID
+
+    def _get_next_id(self):
+        return uuid.uuid4()
+
+
+class UUIDB(TinyDB):
+    table_class = UUIDTable
+
+
+class TinyThreadSafeDB:
+    def __init__(self, db_path: str):
+        self.db = UUIDB(db_path)
+        self._lock1 = TLock()
+        self._lock2 = ALock()
+
+    @asynccontextmanager
+    async def atomic_operation(self):
+        """Context manager for thread-safe database operations"""
+        with self._lock1:
+            async with self._lock2:
+                yield self.db
+
+    async def insert(self, data: dict):
+        """Thread-safe insertion of preference data"""
+        async with self.atomic_operation() as db:
+            db.insert(data)
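Note: `TinyThreadSafeDB` wraps a TinyDB file with UUID document ids behind both a threading and an asyncio lock, and exposes a single async `insert`. A minimal usage sketch (illustrative only; the file name and vote payload below are made up, not from this commit):

# usage_sketch.py -- illustrative only
import asyncio

from talk_arena.db_utils import TinyThreadSafeDB

async def main():
    db = TinyThreadSafeDB("example_votes.json")  # TinyDB JSON file keyed by UUIDs
    # Each vote is stored as a plain dict, matching what clear_factory() writes.
    await db.insert({"audio_hash": "deadbeef", "outcome": 0.5, "model_a": "gpt4o", "model_b": "gemini_2f"})

asyncio.run(main())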
talk_arena/demo.py
ADDED
@@ -0,0 +1,432 @@
+import argparse
+import asyncio
+import random
+import textwrap
+import time
+
+import gradio as gr
+import xxhash
+from dotenv import load_dotenv
+from transformers import pipeline
+
+import talk_arena.streaming_helpers as sh
+from talk_arena.db_utils import TinyThreadSafeDB
+
+
+load_dotenv()
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Talk Arena Demo")
+    parser.add_argument("--free_only", action="store_true", help="Only use free models")
+    return parser.parse_args()
+
+
+args = parse_args()
+
+if gr.NO_RELOAD: # Prevents Re-init during hot reloading
+    # Transcription Disabled for Public Interface
+    # asr_pipe = pipeline(
+    #     task="automatic-speech-recognition",
+    #     model="openai/whisper-large-v3-turbo",
+    #     chunk_length_s=30,
+    #     device="cuda:1",
+    # )
+
+    anonymous = True
+
+    # Generation Setup
+    diva_audio, diva = sh.api_streaming("WillHeld/DiVA-llama-3-v0-8b")
+    qwen2_audio, qwen2 = sh.api_streaming("Qwen/Qwen2-Audio-7B-Instruct")
+    pipelined_system, pipeline_model = sh.api_streaming("pipeline/meta-llama/Meta-Llama-3-8B-Instruct")
+    if not args.free_only:
+        gemini_audio, gemini_model = sh.gemini_streaming("models/gemini-1.5-flash")
+        gpt4o_audio, gpt4o_model = sh.gpt4o_streaming("models/gpt4o")
+        geminip_audio, geminip_model = sh.gemini_streaming("models/gemini-1.5-pro")
+        gemini2_audio, gemini2_model = sh.gemini_streaming("models/gemini-2.0-flash-exp")
+    typhoon_audio, typhoon_model = sh.api_streaming("scb10x/llama-3-typhoon-audio-8b-2411")
+
+    competitor_info = [
+        (sh.gradio_gen_factory(diva_audio, "DiVA Llama 3 8B", anonymous), "diva_3_8b", "DiVA Llama 3 8B"),
+        (sh.gradio_gen_factory(qwen2_audio, "Qwen 2", anonymous), "qwen2", "Qwen 2 Audio"),
+        (
+            sh.gradio_gen_factory(pipelined_system, "Pipelined Llama 3 8B", anonymous),
+            "pipe_l3.0",
+            "Pipelined Llama 3 8B",
+        ),
+        (sh.gradio_gen_factory(typhoon_audio, "Typhoon Audio", anonymous), "typhoon_audio", "Typhoon Audio"),
+    ]
+    # Add paid models if flag is not set
+    if not args.free_only:
+        competitor_info += [
+            (sh.gradio_gen_factory(gemini_audio, "Gemini 1.5 Flash", anonymous), "gemini_1.5f", "Gemini 1.5 Flash"),
+            (sh.gradio_gen_factory(gpt4o_audio, "GPT4o", anonymous), "gpt4o", "GPT-4o"),
+            (sh.gradio_gen_factory(geminip_audio, "Gemini 1.5 Pro", anonymous), "gemini_1.5p", "Gemini 1.5 Pro"),
+            (sh.gradio_gen_factory(geminip_audio, "Gemini 2 Flash", anonymous), "gemini_2f", "Gemini 2 Flash"),
+        ]
+
+    resp_generators = [generator for generator, _, _ in competitor_info]
+    model_shorthand = [shorthand for _, shorthand, _ in competitor_info]
+    model_name = [full_name for _, _, full_name in competitor_info]
+    all_models = list(range(len(model_shorthand)))
+
+
+async def pairwise_response_async(audio_input, state, model_order):
+    if audio_input == None:
+        raise StopAsyncIteration(
+            "",
+            "",
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            state,
+            audio_input,
+            None,
+            None,
+            None,
+        )
+    spinner_id = 0
+    spinners = ["◐ ", "◓ ", "◑", "◒"]
+    spinner = spinners[0]
+    gen_pair = [resp_generators[model_order[0]], resp_generators[model_order[1]]]
+    latencies = [{}, {}]  # Store timing info for each model
+    resps = [gr.Textbox(value="", info="", visible=False), gr.Textbox(value="", info="", visible=False)]
+
+    error_in_model = False
+    for order, generator in enumerate(gen_pair):
+        start_time = time.time()
+        first_token = True
+        total_length = 0
+        try:
+            async for local_resp in generator(audio_input, order):
+                total_length += 1
+                if first_token:
+                    latencies[order]["time_to_first_token"] = time.time() - start_time
+                    first_token = False
+                resps[order] = local_resp
+                spinner = spinners[spinner_id]
+                spinner_id = (spinner_id + 1) % 4
+                yield (
+                    gr.Button(
+                        value=spinner + " Generating Responses " + spinner,
+                        interactive=False,
+                        variant="primary",
+                    ),
+                    resps[0],
+                    resps[1],
+                    gr.Button(visible=False),
+                    gr.Button(visible=False),
+                    gr.Button(visible=False),
+                    state,
+                    audio_input,
+                    None,
+                    None,
+                    latencies,
+                )
+            latencies[order]["total_time"] = time.time() - start_time
+            latencies[order]["response_length"] = total_length
+        except:
+            error_in_model = True
+            resps[order] = gr.Textbox(
+                info=f"<strong>Error thrown by Model {order+1} API</strong>",
+                value="" if first_token else resps[order]._constructor_args[0]["value"],
+                visible=True,
+                label=f"Model {order+1}",
+            )
+            yield (
+                gr.Button(
+                    value=spinner + " Generating Responses " + spinner,
+                    interactive=False,
+                    variant="primary",
+                ),
+                resps[0],
+                resps[1],
+                gr.Button(visible=False),
+                gr.Button(visible=False),
+                gr.Button(visible=False),
+                state,
+                audio_input,
+                None,
+                None,
+                latencies,
+            )
+            latencies[order]["total_time"] = time.time() - start_time
+            latencies[order]["response_length"] = total_length
+    print(latencies)
+    yield (
+        gr.Button(value="Vote for which model is better!", interactive=False, variant="primary", visible=False),
+        resps[0],
+        resps[1],
+        gr.Button(visible=not error_in_model),
+        gr.Button(visible=not error_in_model),
+        gr.Button(visible=not error_in_model),
+        responses_complete(state),
+        audio_input,
+        gr.Textbox(visible=False),
+        gr.Audio(visible=False),
+        latencies,
+    )
+
+
+def on_page_load(state, model_order):
+    if state == 0:
+        # gr.Info(
+        #     "Record something you'd say to an AI Assistant! Think about what you usually use Siri, Google Assistant,"
+        #     " or ChatGPT for."
+        # )
+        state = 1
+        model_order = random.sample(all_models, 2) if anonymous else model_order
+    return state, model_order
+
+
+def recording_complete(state):
+    if state == 1:
+        # gr.Info(
+        #     "Once you submit your recording, you'll receive responses from different models. This might take a second."
+        # )
+        state = 2
+    return (
+        gr.Button(value="Starting Generation", interactive=False, variant="primary"),
+        state,
+    )
+
+
+def responses_complete(state):
+    if state == 2:
+        gr.Info(
+            "Give us your feedback! Mark which model gave you the best response so we can understand the quality of"
+            " these different voice assistant models."
+        )
+        state = 3
+    return state
+
+
+def clear_factory(button_id):
+    async def clear(audio_input, model_order, pref_counter, reasoning, latency):
+        textbox1 = gr.Textbox(visible=False)
+        textbox2 = gr.Textbox(visible=False)
+        if button_id != None:
+            sr, y = audio_input
+            x = xxhash.xxh32(bytes(y)).hexdigest()
+            await db.insert(
+                {
+                    "audio_hash": x,
+                    "outcome": button_id,
+                    "model_a": model_shorthand[model_order[0]],
+                    "model_b": model_shorthand[model_order[1]],
+                    "why": reasoning,
+                    "model_a_latency": latency[0],
+                    "model_b_latency": latency[1],
+                }
+            )
+            pref_counter += 1
+            model_a = model_name[model_order[0]]
+            model_b = model_name[model_order[1]]
+            textbox1 = gr.Textbox(
+                visible=True,
+                info=f"<strong style='color: #53565A'>Response from {model_a}</strong><p>Time-to-First-Character: {latency[0]['time_to_first_token']:.2f} ms, Time Per Character: {latency[0]['total_time']/latency[0]['response_length']:.2f} ms</p>",
+            )
+            textbox2 = gr.Textbox(
+                visible=True,
+                info=f"<strong style='color: #53565A'>Response from {model_b}</strong><p>Time-to-First-Character: {latency[1]['time_to_first_token']:.2f} ms, Time Per Character: {latency[1]['total_time']/latency[1]['response_length']:.2f} ms</p>",
+            )
+
+        try:
+            sr, y = audio_input
+            x = xxhash.xxh32(bytes(y)).hexdigest()
+            os.remove(f"{x}.wav")
+        except:
+            # file already deleted, this is just a failsafe to assure data is cleared
+            pass
+        counter_text = f"# {pref_counter}/10 Preferences Submitted"
+        if pref_counter >= 10 and False: # Currently Disabled, Manages Prolific Completionx
+            code = "PLACEHOLDER"
+            counter_text = f"# Completed! Completion Code: {code}"
+        counter_text = ""
+        if anonymous:
+            model_order = random.sample(all_models, 2)
+        return (
+            model_order,
+            gr.Button(
+                value="Record Audio to Submit Again!",
+                interactive=False,
+                visible=True,
+            ),
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            gr.Button(visible=False),
+            None,
+            textbox1,
+            textbox2,
+            pref_counter,
+            counter_text,
+            gr.Textbox(visible=False),
+            gr.Audio(visible=False),
+        )
+
+    return clear
+
+
+def transcribe(transc, voice_reason):
+    if transc is None:
+        transc = ""
+    transc += " " + asr_pipe(voice_reason, generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"]
+    return transc, gr.Audio(value=None)
+
+
+theme = gr.themes.Soft(
+    primary_hue=gr.themes.Color(
+        c100="#82000019",
+        c200="#82000033",
+        c300="#8200004c",
+        c400="#82000066",
+        c50="#8200007f",
+        c500="#8200007f",
+        c600="#82000099",
+        c700="#820000b2",
+        c800="#820000cc",
+        c900="#820000e5",
+        c950="#820000f2",
+    ),
+    secondary_hue="rose",
+    neutral_hue="stone",
+)
+
+with open("src/talk_arena/styles.css", "r") as css_file:
+    custom_css = css_file.read()
+
+db = TinyThreadSafeDB("live_votes.json")
+
+with gr.Blocks(theme=theme, fill_height=True, css=custom_css) as demo:
+    submitted_preferences = gr.State(0)
+    state = gr.State(0)
+    model_order = gr.State([])
+    latency = gr.State([])
+    with gr.Row():
+        counter_text = gr.Markdown(
+            ""
+        ) # "# 0/10 Preferences Submitted.\n Follow the pop-up tips to submit your first preference.")
+    with gr.Row():
+        audio_input = gr.Audio(sources=["microphone"], streaming=False, label="Audio Input")
+
+    with gr.Row(equal_height=True):
+        with gr.Column(scale=1):
+            out1 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+        with gr.Column(scale=1):
+            out2 = gr.Textbox(visible=False, lines=5, autoscroll=True)
+
+    with gr.Row():
+        btn = gr.Button(value="Record Audio to Submit!", interactive=False)
+
+    with gr.Row(equal_height=True):
+        reason = gr.Textbox(label="[Optional] Explain Your Preferences", visible=False, scale=4)
+        reason_record = gr.Audio(
+            sources=["microphone"],
+            interactive=True,
+            streaming=False,
+            label="Speak to transcribe!",
+            visible=False,
+            type="filepath",
+            # waveform_options={"show_recording_waveform": False},
+            scale=1,
+        )
+
+    with gr.Row():
+        best1 = gr.Button(value="Model 1 is better", visible=False)
+        tie = gr.Button(value="Tie", visible=False)
+        best2 = gr.Button(value="Model 2 is better", visible=False)
+
+    with gr.Row():
+        contact = gr.Markdown("")
+
+    # reason_record.stop_recording(transcribe, inputs=[reason, reason_record], outputs=[reason, reason_record])
+    audio_input.stop_recording(
+        recording_complete,
+        [state],
+        [btn, state],
+    ).then(
+        fn=pairwise_response_async,
+        inputs=[audio_input, state, model_order],
+        outputs=[btn, out1, out2, best1, best2, tie, state, audio_input, reason, reason_record, latency],
+    )
+    audio_input.start_recording(
+        lambda: gr.Button(value="Uploading Audio to Cloud", interactive=False, variant="primary"),
+        None,
+        btn,
+    )
+    best1.click(
+        fn=clear_factory(0),
+        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+        outputs=[
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    tie.click(
+        fn=clear_factory(0.5),
+        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+        outputs=[
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    best2.click(
+        fn=clear_factory(1),
+        inputs=[audio_input, model_order, submitted_preferences, reason, latency],
+        outputs=[
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    audio_input.clear(
+        clear_factory(None),
+        [audio_input, model_order, submitted_preferences, reason, latency],
+        [
+            model_order,
+            btn,
+            best1,
+            best2,
+            tie,
+            audio_input,
+            out1,
+            out2,
+            submitted_preferences,
+            counter_text,
+            reason,
+            reason_record,
+        ],
+    )
+    demo.load(fn=on_page_load, inputs=[state, model_order], outputs=[state, model_order])
+
+if __name__ == "__main__":
+    demo.queue(default_concurrency_limit=40, api_open=False).launch(share=True, ssr_mode=False)
talk_arena/leaderboard_viz.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
import textwrap
|
4 |
+
from collections import defaultdict
|
5 |
+
from datetime import datetime
|
6 |
+
from typing import Dict, List, Tuple
|
7 |
+
from zoneinfo import ZoneInfo
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
import plotly.express as px
|
13 |
+
import plotly.io as pio
|
14 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
15 |
+
from scipy.optimize import minimize
|
16 |
+
from scipy.special import expit
|
17 |
+
|
18 |
+
|
19 |
+
# Constants
|
20 |
+
COLORS = [
|
21 |
+
"#1B7FFF",
|
22 |
+
"#F07D1A",
|
23 |
+
"#BA24C7",
|
24 |
+
"#FE42C7",
|
25 |
+
"#0D4B7C",
|
26 |
+
"#0EAC96",
|
27 |
+
"#AA7CFF",
|
28 |
+
"#B50550",
|
29 |
+
"#009EEB",
|
30 |
+
"#220B55",
|
31 |
+
"#7B3301",
|
32 |
+
]
|
33 |
+
WR_PLOT = None
|
34 |
+
BT_PLOT = None
|
35 |
+
UPDATE_TIME = None
|
36 |
+
NAME_MAPPING = {
|
37 |
+
"gemini_2f": "Gemini 2.0 Flash (Experimental)",
|
38 |
+
"diva_3_8b": "DiVA Llama 3 8B",
|
39 |
+
"qwen2": "Qwen 2 Audio",
|
40 |
+
"pipe_l3.0": "Pipelined Llama 3 8B",
|
41 |
+
"gemini_1.5f": "Gemini 1.5 Flash",
|
42 |
+
"gpt4o": "GPT-4o",
|
43 |
+
"gemini_1.5p": "Gemini 1.5 Pro",
|
44 |
+
"typhoon_audio": "Typhoon Audio",
|
45 |
+
}
|
46 |
+
|
47 |
+
|
48 |
+
def get_aesthetic_timestamp():
|
49 |
+
"""
|
50 |
+
Returns a beautifully formatted timestamp in the format:
|
51 |
+
'Tuesday, December 10th, 2024 at 3:45 PM'
|
52 |
+
"""
|
53 |
+
# Get timezone object for PST
|
54 |
+
pst = ZoneInfo("America/Los_Angeles")
|
55 |
+
|
56 |
+
# Get current time in PST
|
57 |
+
now = datetime.now(pst)
|
58 |
+
|
59 |
+
# Add suffix to day number (1st, 2nd, 3rd, etc.)
|
60 |
+
day = now.day
|
61 |
+
if 4 <= day <= 20 or 24 <= day <= 30:
|
62 |
+
suffix = "th"
|
63 |
+
else:
|
64 |
+
suffix = ["st", "nd", "rd"][day % 10 - 1]
|
65 |
+
return now.strftime(f"%A, %B {day}{suffix}, %Y at %-I:%M %p")
|
66 |
+
|
67 |
+
|
68 |
+
def bootstrap_ci(data, n_bootstrap=10000, ci=95):
|
69 |
+
"""Calculate bootstrap confidence intervals."""
|
70 |
+
bootstrap_samples = []
|
71 |
+
for _ in range(n_bootstrap):
|
72 |
+
bootstrap_samples.append(np.mean(random.choices(data, k=len(data))))
|
73 |
+
lower_bound = np.percentile(bootstrap_samples, (100 - ci) / 2)
|
74 |
+
upper_bound = np.percentile(bootstrap_samples, 100 - (100 - ci) / 2)
|
75 |
+
return lower_bound, upper_bound
|
76 |
+
|
77 |
+
|
78 |
+
def calculate_win_rates(json_data):
|
79 |
+
"""Calculate win rates from JSON data."""
|
80 |
+
data = json.loads(json_data)
|
81 |
+
|
82 |
+
model_wins = defaultdict(int)
|
83 |
+
total_matches = defaultdict(int)
|
84 |
+
total_votes = 0
|
85 |
+
|
86 |
+
for value in data["_default"].values():
|
87 |
+
total_votes += 1
|
88 |
+
if value["outcome"] == 0:
|
89 |
+
model_wins[value["model_a"]] += 1
|
90 |
+
elif value["outcome"] == 1:
|
91 |
+
model_wins[value["model_b"]] += 1
|
92 |
+
elif value["outcome"] == 0.5:
|
93 |
+
model_wins[value["model_a"]] += 0.5
|
94 |
+
model_wins[value["model_b"]] += 0.5
|
95 |
+
total_matches[value["model_a"]] += 1
|
96 |
+
total_matches[value["model_b"]] += 1
|
97 |
+
|
98 |
+
per_model_wins = {}
|
99 |
+
for model, wins in model_wins.items():
|
100 |
+
win_rate = wins / total_matches[model]
|
101 |
+
wins_data = [1] * int(wins) + [0] * int(total_matches[model] - wins)
|
102 |
+
if int(wins) != wins:
|
103 |
+
wins_data += [0.5]
|
104 |
+
lower, upper = bootstrap_ci(wins_data)
|
105 |
+
per_model_wins[model] = {
|
106 |
+
"model": model,
|
107 |
+
"win_rate": win_rate,
|
108 |
+
"95_lower": (win_rate - lower),
|
109 |
+
"95_upper": (upper - win_rate),
|
110 |
+
}
|
111 |
+
df = pd.DataFrame.from_dict(per_model_wins).T
|
112 |
+
|
113 |
+
return df, total_votes
|
114 |
+
|
115 |
+
|
116 |
+
def create_win_rate_plot(wins_df):
|
117 |
+
"""Create win rate plot using Plotly."""
|
118 |
+
wins_df["Source"] = wins_df["Source"].astype(str)
|
119 |
+
wins_df = wins_df.sort_values(by=["Source", "win_rate"], ascending=False)
|
120 |
+
wins_df["model"] = wins_df["model"].apply(lambda x: NAME_MAPPING.get(x, x))
|
121 |
+
|
122 |
+
fig = px.bar(
|
123 |
+
wins_df,
|
124 |
+
x="model",
|
125 |
+
y="win_rate",
|
126 |
+
error_y="95_upper",
|
127 |
+
error_y_minus="95_lower",
|
128 |
+
color="model",
|
129 |
+
color_discrete_sequence=COLORS,
|
130 |
+
animation_group="model",
|
131 |
+
animation_frame="Source",
|
132 |
+
)
|
133 |
+
|
134 |
+
fig.update_traces(
|
135 |
+
hovertemplate="<b>%{x}</b><br>" + "Win Rate: %{y}" + "<extra></extra>",
|
136 |
+
)
|
137 |
+
|
138 |
+
fig.update_layout(
|
139 |
+
autosize=True,
|
140 |
+
showlegend=False,
|
141 |
+
plot_bgcolor="white",
|
142 |
+
title={
|
143 |
+
"text": "Talk Arena Live Win Rates<br>with 95% Confidence Intervals",
|
144 |
+
"y": 0.95,
|
145 |
+
"x": 0.5,
|
146 |
+
"xanchor": "center",
|
147 |
+
"yanchor": "top",
|
148 |
+
},
|
149 |
+
xaxis_title="Model",
|
150 |
+
yaxis_title="Win Rate (%)",
|
151 |
+
bargap=0.2,
|
152 |
+
yaxis=dict(
|
153 |
+
tickformat=".0%", tickmode="auto", range=[0, 1.01], gridcolor="#C9CCD1", griddash="dash", gridwidth=2
|
154 |
+
),
|
155 |
+
legend=dict(
|
156 |
+
orientation="h", # Make legend horizontal
|
157 |
+
yanchor="bottom",
|
158 |
+
y=-0.5, # Position below plot
|
159 |
+
xanchor="center",
|
160 |
+
x=0.5, # Center horizontally
|
161 |
+
bgcolor="rgba(255, 255, 255, 0.8)",
|
162 |
+
bordercolor="#C9CCD1",
|
163 |
+
borderwidth=1,
|
164 |
+
),
|
165 |
+
margin=dict(l=10, r=10, t=0, b=10), # Balanced margins
|
166 |
+
hoverlabel=dict(bgcolor="white", font_size=14, bordercolor="gray"),
|
167 |
+
)
|
168 |
+
|
169 |
+
fig.update_xaxes(showgrid=False)
|
170 |
+
|
171 |
+
return fig
|
172 |
+
|
173 |
+
|
174 |
+
# Bradley-Terry Model Functions
|
175 |
+
def load_live_votes(json_str: str) -> pd.DataFrame:
|
176 |
+
"""Load and preprocess live votes data from JSON string."""
|
177 |
+
data = json.loads(json_str)
|
178 |
+
df = pd.DataFrame.from_dict(data["_default"], orient="index")
|
179 |
+
df["winner"] = df["outcome"].map({1: "model_b", 0: "model_a", 0.5: "tie"})
|
180 |
+
return df
|
181 |
+
|
182 |
+
|
183 |
+
def preprocess_for_bt(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, List[str], np.ndarray]:
|
184 |
+
"""Preprocess data for Bradley-Terry model fitting."""
|
185 |
+
all_models = pd.concat([df["model_a"], df["model_b"]]).unique()
|
186 |
+
model_to_idx = {model: idx for idx, model in enumerate(all_models)}
|
187 |
+
|
188 |
+
matchups = np.array([[model_to_idx[row.model_a], model_to_idx[row.model_b]] for _, row in df.iterrows()])
|
189 |
+
|
190 |
+
outcomes = np.array(
|
191 |
+
[1.0 if row.winner == "model_a" else (0.5 if row.winner == "tie" else 0.0) for _, row in df.iterrows()]
|
192 |
+
)
|
193 |
+
|
194 |
+
unique_matches = np.column_stack([matchups, outcomes])
|
195 |
+
unique_matches, weights = np.unique(unique_matches, return_counts=True, axis=0)
|
196 |
+
|
197 |
+
return (unique_matches[:, :2].astype(np.int32), unique_matches[:, 2], list(all_models), weights.astype(np.float64))
|
198 |
+
|
199 |
+
|
200 |
+
def bt_loss_and_grad(
|
201 |
+
ratings: np.ndarray, matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, alpha: float = 1.0
|
202 |
+
) -> Tuple[float, np.ndarray]:
|
203 |
+
"""Compute Bradley-Terry loss and gradient."""
|
204 |
+
matchup_ratings = ratings[matchups]
|
205 |
+
logits = alpha * (matchup_ratings[:, 0] - matchup_ratings[:, 1])
|
206 |
+
probs = expit(logits)
|
207 |
+
|
208 |
+
loss = -((np.log(probs) * outcomes + np.log(1.0 - probs) * (1.0 - outcomes)) * weights).sum()
|
209 |
+
|
210 |
+
matchups_grads = -alpha * (outcomes - probs) * weights
|
211 |
+
model_grad = np.zeros_like(ratings)
|
212 |
+
np.add.at(model_grad, matchups[:, [0, 1]], matchups_grads[:, None] * np.array([1.0, -1.0], dtype=np.float64))
|
213 |
+
|
214 |
+
return loss, model_grad
|
215 |
+
|
216 |
+
|
217 |
+
def fit_bt(
|
218 |
+
matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, n_models: int, alpha: float, tol: float = 1e-6
|
219 |
+
) -> np.ndarray:
|
220 |
+
"""Fit Bradley-Terry model using L-BFGS-B optimization."""
|
221 |
+
initial_ratings = np.zeros(n_models, dtype=np.float64)
|
222 |
+
|
223 |
+
result = minimize(
|
224 |
+
fun=bt_loss_and_grad,
|
225 |
+
x0=initial_ratings,
|
226 |
+
args=(matchups, outcomes, weights, alpha),
|
227 |
+
jac=True,
|
228 |
+
method="L-BFGS-B",
|
229 |
+
options={"disp": False, "maxiter": 100, "gtol": tol},
|
230 |
+
)
|
231 |
+
|
232 |
+
return result["x"]
|
233 |
+
|
234 |
+
|
235 |
+
def scale_and_offset(
|
236 |
+
ratings: np.ndarray, models: List[str], scale: float = 400, init_rating: float = 1000
|
237 |
+
) -> np.ndarray:
|
238 |
+
"""Scale ratings to familiar Elo-like scale."""
|
239 |
+
scaled_ratings = (ratings * scale) + init_rating
|
240 |
+
return scaled_ratings
|
241 |
+
|
242 |
+
|
243 |
+
def compute_bootstrap_bt(
|
244 |
+
data: str,
|
245 |
+
num_round: int = 100,
|
246 |
+
base: float = 10.0,
|
247 |
+
scale: float = 400.0,
|
248 |
+
init_rating: float = 1000.0,
|
249 |
+
tol: float = 1e-6,
|
250 |
+
) -> pd.DataFrame:
|
251 |
+
"""Compute bootstrap Bradley-Terry ratings from live votes data."""
|
252 |
+
df = load_live_votes(data)
|
253 |
+
matchups, outcomes, models, weights = preprocess_for_bt(df)
|
254 |
+
|
255 |
+
rng = np.random.default_rng(seed=0)
|
256 |
+
total_matches = len(df)
|
257 |
+
idxs = rng.multinomial(n=total_matches, pvals=weights / weights.sum(), size=num_round)
|
258 |
+
boot_weights = idxs.astype(np.float64) / total_matches
|
259 |
+
|
260 |
+
ratings_list = []
|
261 |
+
for sample_weights in boot_weights:
|
262 |
+
ratings = fit_bt(
|
263 |
+
matchups=matchups,
|
264 |
+
outcomes=outcomes,
|
265 |
+
weights=sample_weights,
|
266 |
+
n_models=len(models),
|
267 |
+
alpha=np.log(base),
|
268 |
+
tol=tol,
|
269 |
+
)
|
270 |
+
scaled_ratings = scale_and_offset(ratings=ratings, models=models, scale=scale, init_rating=init_rating)
|
271 |
+
ratings_list.append(scaled_ratings)
|
272 |
+
|
273 |
+
df_ratings = pd.DataFrame(ratings_list, columns=models)
|
274 |
+
return df_ratings[df_ratings.median().sort_values(ascending=False).index]
|
275 |
+
|
276 |
+
|
277 |
+
def create_bt_plot(bootstrap_ratings):
|
278 |
+
"""Create Bradley-Terry ratings plot using Plotly."""
|
279 |
+
melted_bootstrap = bootstrap_ratings.melt(id_vars=["Source", "level_1"], var_name="Model", value_name="BT")
|
280 |
+
melted_bootstrap = melted_bootstrap.dropna()
|
281 |
+
melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "Model", "BT"], ascending=False)
|
282 |
+
# Pretty Names
|
283 |
+
melted_bootstrap["Model"] = melted_bootstrap["Model"].apply(lambda x: NAME_MAPPING.get(x, x))
|
284 |
+
# Compression for Client Side
|
285 |
+
melted_bootstrap["BT"] = melted_bootstrap["BT"].apply(lambda x: int(x))
|
286 |
+
min_samp = melted_bootstrap[melted_bootstrap["BT"] > 0]["BT"].min()
|
287 |
+
max_samp = melted_bootstrap["BT"].max()
|
288 |
+
idx_keep = list(range(0, len(melted_bootstrap), 10))
|
289 |
+
melted_bootstrap = melted_bootstrap.iloc[idx_keep]
|
290 |
+
melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "BT"], ascending=False)
|
291 |
+
fig = px.violin(
|
292 |
+
melted_bootstrap,
|
293 |
+
x="Model",
|
294 |
+
y="BT",
|
295 |
+
color="Model",
|
296 |
+
animation_group="Model",
|
297 |
+
animation_frame="Source",
|
298 |
+
color_discrete_sequence=COLORS,
|
299 |
+
)
|
300 |
+
|
301 |
+
fig.update_layout(
|
302 |
+
autosize=True,
|
303 |
+
showlegend=False,
|
304 |
+
plot_bgcolor="white",
|
305 |
+
title={
|
306 |
+
"text": "Talk Arena Live Bradley-Terry Ratings<br>with Bootstrapped Variance",
|
307 |
+
"y": 0.9,
|
308 |
+
"x": 0.5,
|
309 |
+
"xanchor": "center",
|
310 |
+
"yanchor": "top",
|
311 |
+
},
|
312 |
+
xaxis_title="Model",
|
313 |
+
yaxis_title="Rating",
|
314 |
+
yaxis=dict(gridcolor="#C9CCD1", range=[min_samp - 10, max_samp + 10], griddash="dash"),
|
315 |
+
legend=dict(
|
316 |
+
orientation="h", # Make legend horizontal
|
317 |
+
yanchor="bottom",
|
318 |
+
y=-0.5, # Position below plot
|
319 |
+
xanchor="center",
|
320 |
+
x=0.5, # Center horizontally
|
321 |
+
bgcolor="rgba(255, 255, 255, 0.8)",
|
322 |
+
bordercolor="#C9CCD1",
|
323 |
+
borderwidth=1,
|
324 |
+
),
|
325 |
+
margin=dict(l=10, r=10, t=0, b=10), # Balanced margins
|
326 |
+
)
|
327 |
+
|
328 |
+
fig.update_xaxes(showgrid=False)
|
329 |
+
fig.update_yaxes(showgrid=True, gridwidth=2)
|
330 |
+
|
331 |
+
return fig
|
332 |
+
|
333 |
+
|
334 |
+
def get_wr_plot():
|
335 |
+
jrep = json.loads(pio.to_json(WR_PLOT))
|
336 |
+
for step in jrep["layout"]["sliders"][0]["steps"]:
|
337 |
+
step["args"][1]["frame"]["duration"] = 500
|
338 |
+
step["args"][1]["transition"]["duration"] = 500
|
339 |
+
jrep["layout"]["updatemenus"] = []
|
340 |
+
jrep["layout"]["sliders"][0]["len"] = 0.8
|
341 |
+
jrep["layout"]["sliders"][0]["pad"] = {}
|
342 |
+
return json.dumps(jrep)
|
343 |
+
|
344 |
+
|
345 |
+
def get_bt_plot():
|
346 |
+
jrep = json.loads(pio.to_json(BT_PLOT))
|
347 |
+
for step in jrep["layout"]["sliders"][0]["steps"]:
|
348 |
+
step["args"][1]["frame"]["duration"] = 500
|
349 |
+
step["args"][1]["transition"]["duration"] = 500
|
350 |
+
jrep["layout"]["updatemenus"] = []
|
351 |
+
jrep["layout"]["sliders"][0]["len"] = 0.8
|
352 |
+
jrep["layout"]["sliders"][0]["pad"] = {}
|
353 |
+
return json.dumps(jrep)
|
354 |
+
|
355 |
+
|
356 |
+
def get_update_time():
|
357 |
+
return UPDATE_TIME
|
358 |
+
|
359 |
+
|
360 |
+
def viz_factory(force=False):
|
361 |
+
def process_and_visualize():
|
362 |
+
"""Main function to process JSON data and create visualizations."""
|
363 |
+
global WR_PLOT, BT_PLOT, UPDATE_TIME
|
364 |
+
if WR_PLOT is not None and BT_PLOT is not None and not force:
|
365 |
+
return WR_PLOT, BT_PLOT, UPDATE_TIME
|
366 |
+
try:
|
367 |
+
# Read JSON data
|
368 |
+
pub_json_data = open("/home/wheld3/talk-arena/live_votes.json", "r").read()
|
369 |
+
prolific_json_data = open("/home/wheld3/talk-arena/prolific_votes.json", "r").read()
|
370 |
+
merged_json_data = json.dumps(
|
371 |
+
{"_default": {**json.loads(pub_json_data)["_default"], **json.loads(prolific_json_data)["_default"]}}
|
372 |
+
)
|
373 |
+
# Calculate win rates and create win rate plot
|
374 |
+
pub_win_rates, pub_votes = calculate_win_rates(pub_json_data)
|
375 |
+
pro_win_rates, pro_votes = calculate_win_rates(prolific_json_data)
|
376 |
+
total_win_rates, total_votes = calculate_win_rates(merged_json_data)
|
377 |
+
all_models = total_win_rates["model"].unique()
|
378 |
+
pro_models = pro_win_rates["model"].unique()
|
379 |
+
for model in all_models:
|
380 |
+
if model not in pro_models:
|
381 |
+
new_index = len(pro_win_rates)
|
382 |
+
pro_win_rates.loc[new_index] = [model, -0.1, -0.1, -0.2]
|
383 |
+
|
384 |
+
win_rates = (
|
385 |
+
pd.concat([pub_win_rates, pro_win_rates, total_win_rates], keys=["Public", "Prolific", "Total"])
|
386 |
+
.reset_index()
|
387 |
+
.rename(columns={"level_0": "Source"})
|
388 |
+
)
|
389 |
+
WR_PLOT = create_win_rate_plot(win_rates)
|
390 |
+
|
391 |
+
# Calculate Bradley-Terry ratings and create BT plot
|
392 |
+
pub_bootstrap_ratings = compute_bootstrap_bt(pub_json_data, num_round=10000)
|
393 |
+
pro_bootstrap_ratings = compute_bootstrap_bt(prolific_json_data, num_round=10000)
|
394 |
+
total_bootstrap_ratings = compute_bootstrap_bt(merged_json_data, num_round=10000)
|
395 |
+
for model in all_models:
|
396 |
+
if model not in pro_models:
|
397 |
+
pro_bootstrap_ratings[model] = pro_bootstrap_ratings["diva_3_8b"] * -1
|
398 |
+
|
399 |
+
bootstrap_ratings = (
|
400 |
+
pd.concat(
|
401 |
+
[pub_bootstrap_ratings, pro_bootstrap_ratings, total_bootstrap_ratings],
|
402 |
+
keys=["Public", "Prolific", "Total"],
|
403 |
+
)
|
404 |
+
.reset_index()
|
405 |
+
.rename(columns={"level_0": "Source"})
|
406 |
+
)
|
407 |
+
BT_PLOT = create_bt_plot(bootstrap_ratings)
|
408 |
+
UPDATE_TIME = gr.Markdown(
|
409 |
+
value=textwrap.dedent(
|
410 |
+
f"""
|
411 |
+
<h4 class="nx-font-semibold nx-tracking-tight nx-text-slate-900 dark:nx-text-slate-100 nx-text-xl">Last Refresh: {get_aesthetic_timestamp()} PST</h4>
|
412 |
+
<h6 class="nx-font-semibold nx-tracking-tight nx-text-slate-900 dark:nx-text-slate-100 nx-text-base">Total Votes: {total_votes}, Public Votes: {pub_votes}, Prolific Votes: {pro_votes}</h6>
|
413 |
+
"""
|
414 |
+
)
|
415 |
+
)
|
416 |
+
return WR_PLOT, BT_PLOT, UPDATE_TIME
|
417 |
+
|
418 |
+
except Exception as e:
|
419 |
+
raise gr.Error(f"Error processing file: {str(e)}")
|
420 |
+
|
421 |
+
return process_and_visualize
|
422 |
+
|
423 |
+
|
424 |
+
theme = gr.themes.Soft(
|
425 |
+
primary_hue=gr.themes.Color(
|
426 |
+
c100="#82000019",
|
427 |
+
c200="#82000033",
|
428 |
+
c300="#8200004c",
|
429 |
+
c400="#82000066",
|
430 |
+
c50="#8200007f",
|
431 |
+
c500="#8200007f",
|
432 |
+
c600="#82000099",
|
433 |
+
c700="#820000b2",
|
434 |
+
c800="#820000cc",
|
435 |
+
c900="#820000e5",
|
436 |
+
c950="#820000f2",
|
437 |
+
),
|
438 |
+
secondary_hue="rose",
|
439 |
+
neutral_hue="stone",
|
440 |
+
)
|
441 |
+
|
442 |
+
# Create Gradio interface
|
443 |
+
with gr.Blocks(title="Talk Arena Leaderboard Analysis", theme=theme) as demo:
|
444 |
+
viz_factory(force=True)()
|
445 |
+
last_updated = UPDATE_TIME
|
446 |
+
with gr.Row():
|
447 |
+
bt_plot = gr.Plot(label="Bradley-Terry Ratings", value=BT_PLOT)
|
448 |
+
with gr.Row():
|
449 |
+
win_rate_plot = gr.Plot(label="Win Rates", value=WR_PLOT)
|
450 |
+
|
451 |
+
d1 = gr.Textbox(visible=False)
|
452 |
+
demo.load(
|
453 |
+
fn=viz_factory(force=False), inputs=[], outputs=[win_rate_plot, bt_plot, last_updated], show_progress="minimal"
|
454 |
+
)
|
455 |
+
demo.load(fn=get_wr_plot, inputs=[], outputs=[d1])
|
456 |
+
demo.load(fn=get_bt_plot, inputs=[], outputs=[d1])
|
457 |
+
demo.load(fn=get_update_time, inputs=[], outputs=[d1])
|
458 |
+
|
459 |
+
if __name__ == "__main__":
|
460 |
+
scheduler = BackgroundScheduler()
|
461 |
+
scheduler.add_job(func=viz_factory(force=True), trigger="interval", seconds=300)
|
462 |
+
scheduler.start()
|
463 |
+
demo.queue(default_concurrency_limit=10, api_open=True).launch(share=True, server_port=8004, node_port=8005)
|
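Editor's note on the refresh wiring above: page loads go through viz_factory(force=False), which hands back the cached WR_PLOT/BT_PLOT/UPDATE_TIME once they exist, while the BackgroundScheduler job and the Blocks setup call force a full recompute. A minimal sketch of that contract, using only names defined in this file (the comments are ours):

refresh = viz_factory(force=True)   # always re-reads the vote files and rebuilds both plots
cached = viz_factory(force=False)   # returns the module-level plots untouched once they are populated
refresh()                           # what the 300-second scheduler job runs between page loads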
talk_arena/streaming_helpers.py
ADDED
@@ -0,0 +1,348 @@
1 |
+
import asyncio
|
2 |
+
import base64
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from collections import defaultdict
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import google.generativeai as genai
|
9 |
+
import gradio as gr
|
10 |
+
import librosa
|
11 |
+
import numpy as np
|
12 |
+
import soundfile as sf
|
13 |
+
import torch
|
14 |
+
import xxhash
|
15 |
+
from datasets import Audio
|
16 |
+
from openai import AsyncOpenAI
|
17 |
+
from transformers import AutoModel, AutoProcessor, Qwen2AudioForConditionalGeneration, TextIteratorStreamer
|
18 |
+
from transformers.generation import GenerationConfig
|
19 |
+
|
20 |
+
|
21 |
+
def _get_prompt_for_model_name(model_id):
|
22 |
+
prompt_dict = defaultdict(lambda: "You are a helpful assistant. Respond conversationally to the speech provided.")
|
23 |
+
# Requested Overrides
|
24 |
+
prompt_dict["scb10x/llama-3-typhoon-audio-8b-2411"] = (
|
25 |
+
"You are a helpful assistant. Respond conversationally to the speech provided in the language it is spoken in."
|
26 |
+
)
|
27 |
+
return prompt_dict[model_id]
|
28 |
+
|
29 |
+
|
30 |
+
def _get_config_for_model_name(model_id):
|
31 |
+
if "API_MODEL_CONFIG" in os.environ:
|
32 |
+
return json.loads(os.environ["API_MODEL_CONFIG"])[model_id]
|
33 |
+
return {
|
34 |
+
"pipeline/meta-llama/Meta-Llama-3-8B-Instruct": {"base_url": "http://localhost:8001/v1", "api_key": "empty"},
|
35 |
+
"scb10x/llama-3-typhoon-audio-8b-2411": {
|
36 |
+
"base_url": "http://localhost:8002/v1",
|
37 |
+
"api_key": "empty",
|
38 |
+
},
|
39 |
+
"WillHeld/DiVA-llama-3-v0-8b": {
|
40 |
+
"base_url": "http://localhost:8003/v1",
|
41 |
+
"api_key": "empty",
|
42 |
+
},
|
43 |
+
"Qwen/Qwen2-Audio-7B-Instruct": {
|
44 |
+
"base_url": "http://localhost:8004/v1",
|
45 |
+
"api_key": "empty",
|
46 |
+
},
|
47 |
+
}[model_id]
|
48 |
+
|
49 |
+
|
50 |
+
def gradio_gen_factory(streaming_fn, model_name, anonymous):
|
51 |
+
async def gen_from(audio_input, order):
|
52 |
+
with torch.no_grad():
|
53 |
+
prev_resp = ""
|
54 |
+
async for resp in streaming_fn(audio_input):
|
55 |
+
for char in range(len(prev_resp), len(resp)):
|
56 |
+
my_resp = gr.Textbox(
|
57 |
+
value=resp[: char + 1],
|
58 |
+
info="",
|
59 |
+
visible=True,
|
60 |
+
label=model_name if not anonymous else f"Model {order+1}",
|
61 |
+
elem_classes="lam-response-box",
|
62 |
+
)
|
63 |
+
yield my_resp
|
64 |
+
await asyncio.sleep(0.001)
|
65 |
+
prev_resp = resp
|
66 |
+
|
67 |
+
return gen_from
|
68 |
+
|
69 |
+
|
70 |
+
def gemini_streaming(model_id):
|
71 |
+
genai.configure(api_key=os.environ["GEMINI_API_KEY"])
|
72 |
+
resampler = Audio(sampling_rate=16_000)
|
73 |
+
model = genai.GenerativeModel(model_id)
|
74 |
+
|
75 |
+
async def get_chat_response(audio_input):
|
76 |
+
if audio_input is None:
|
77 |
+
raise StopAsyncIteration("")
|
78 |
+
sr, y = audio_input
|
79 |
+
x = xxhash.xxh32(bytes(y)).hexdigest()
|
80 |
+
y = y.astype(np.float32)
|
81 |
+
y /= np.max(np.abs(y))
|
82 |
+
a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
|
83 |
+
sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
|
84 |
+
prompt = "You are a helpful assistant. Respond conversationally to the speech provided."
|
85 |
+
inputs = [prompt, {"mime_type": "audio/wav", "data": Path(f"{x}.wav").read_bytes()}]
|
86 |
+
text_response = []
|
87 |
+
responses = model.generate_content(inputs, stream=True)
|
88 |
+
for chunk in responses:
|
89 |
+
text_response.append(chunk.text)
|
90 |
+
yield "".join(text_response)
|
91 |
+
os.remove(f"{x}.wav")
|
92 |
+
|
93 |
+
return get_chat_response, model
|
94 |
+
|
95 |
+
|
96 |
+
def gpt4o_streaming(model_id):
|
97 |
+
client = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
98 |
+
resampler = Audio(sampling_rate=16_000)
|
99 |
+
|
100 |
+
async def get_chat_response(audio_input):
|
101 |
+
if audio_input is None:
|
102 |
+
raise StopAsyncIteration("")
|
103 |
+
sr, y = audio_input
|
104 |
+
x = xxhash.xxh32(bytes(y)).hexdigest()
|
105 |
+
y = y.astype(np.float32)
|
106 |
+
y /= np.max(np.abs(y))
|
107 |
+
a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
|
108 |
+
sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
|
109 |
+
with open(f"{x}.wav", "rb") as wav_file:
|
110 |
+
wav_data = wav_file.read()
|
111 |
+
encoded_string = base64.b64encode(wav_data).decode("utf-8")
|
112 |
+
prompt = "You are a helpful assistant. Respond conversationally to the speech provided."
|
113 |
+
try:
|
114 |
+
completion = await client.chat.completions.create(
|
115 |
+
model="gpt-4o-audio-preview",
|
116 |
+
modalities=["text", "audio"],
|
117 |
+
audio={"voice": "alloy", "format": "wav"},
|
118 |
+
messages=[
|
119 |
+
{
|
120 |
+
"role": "user",
|
121 |
+
"content": [
|
122 |
+
{"type": "text", "text": prompt},
|
123 |
+
{"type": "input_audio", "input_audio": {"data": encoded_string, "format": "wav"}},
|
124 |
+
],
|
125 |
+
},
|
126 |
+
],
|
127 |
+
)
|
128 |
+
os.remove(f"{x}.wav")
|
129 |
+
yield completion.choices[0].message.audio.transcript
|
130 |
+
except Exception:
|
131 |
+
raise StopAsyncIteration("error")
|
132 |
+
|
133 |
+
return get_chat_response, client
|
134 |
+
|
135 |
+
|
136 |
+
async def llm_streaming(model_id: str, prompt: str):
|
137 |
+
if "gpt" in model_id:
|
138 |
+
client = AsyncOpenAI()
|
139 |
+
else:
|
140 |
+
client = AsyncOpenAI(**_get_config_for_model_name(model_id))
|
141 |
+
try:
|
142 |
+
completion = await client.chat.completions.create(
|
143 |
+
model=model_id,
|
144 |
+
messages=[
|
145 |
+
{"role": "system", "content": "You are helpful assistant."},
|
146 |
+
{
|
147 |
+
"role": "user",
|
148 |
+
"content": prompt,
|
149 |
+
},
|
150 |
+
],
|
151 |
+
stream=True,
|
152 |
+
)
|
153 |
+
text_response = []
|
154 |
+
async for chunk in completion:
|
155 |
+
if len(chunk.choices) > 0:
|
156 |
+
text_response.append(chunk.choices[0].delta.content or "")  # delta.content can be None on role/finish chunks
|
157 |
+
yield "".join(text_response)
|
158 |
+
except Exception:
|
159 |
+
raise StopAsyncIteration("error")
|
160 |
+
|
161 |
+
|
162 |
+
def asr_streaming(model_id, asr_pipe):
|
163 |
+
resampler = Audio(sampling_rate=16_000)
|
164 |
+
|
165 |
+
async def pipelined(audio_input):
|
166 |
+
if audio_input is None:
|
167 |
+
raise StopAsyncIteration("")
|
168 |
+
sr, y = audio_input
|
169 |
+
x = xxhash.xxh32(bytes(y)).hexdigest()
|
170 |
+
y = y.astype(np.float32)
|
171 |
+
y /= np.max(np.abs(y))
|
172 |
+
a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
|
173 |
+
sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
|
174 |
+
text = await asyncio.to_thread(
|
175 |
+
asr_pipe(f"{x}.wav", generate_kwargs={"task": "transcribe"}, return_timestamps=False)["text"]
|
176 |
+
)
|
177 |
+
os.remove(f"{x}.wav")
|
178 |
+
async for response in llm_streaming(model_id, prompt=text):
|
179 |
+
yield response
|
180 |
+
|
181 |
+
return pipelined
|
182 |
+
|
183 |
+
|
184 |
+
def api_streaming(model_id):
|
185 |
+
client = AsyncOpenAI(**_get_config_for_model_name(model_id))
|
186 |
+
resampler = Audio(sampling_rate=16_000)
|
187 |
+
|
188 |
+
async def get_chat_response(audio_input):
|
189 |
+
if audio_input is None:
|
190 |
+
raise StopAsyncIteration("")
|
191 |
+
sr, y = audio_input
|
192 |
+
x = xxhash.xxh32(bytes(y)).hexdigest()
|
193 |
+
y = y.astype(np.float32)
|
194 |
+
y /= np.max(np.abs(y))
|
195 |
+
a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
|
196 |
+
sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
|
197 |
+
with open(f"{x}.wav", "rb") as wav_file:
|
198 |
+
wav_data = wav_file.read()
|
199 |
+
encoded_string = base64.b64encode(wav_data).decode("utf-8")
|
200 |
+
try:
|
201 |
+
prompt = _get_prompt_for_model_name(model_id)
|
202 |
+
completion = await client.chat.completions.create(
|
203 |
+
model=model_id,
|
204 |
+
messages=[
|
205 |
+
{
|
206 |
+
"role": "user",
|
207 |
+
"content": [
|
208 |
+
{"type": "text", "text": prompt},
|
209 |
+
{"type": "audio", "audio_url": "data:audio/wav;base64," + encoded_string},
|
210 |
+
],
|
211 |
+
},
|
212 |
+
],
|
213 |
+
stream=True,
|
214 |
+
)
|
215 |
+
text_response = []
|
216 |
+
async for chunk in completion:
|
217 |
+
if len(chunk.choices) > 0:
|
218 |
+
text_response.append(chunk.choices[0].delta.content or "")
|
219 |
+
yield "".join(text_response)
|
220 |
+
os.remove(f"{x}.wav")
|
221 |
+
except Exception:
|
222 |
+
print(f"error for {model_id}")
|
223 |
+
raise StopAsyncIteration(f"error for {model_id}")
|
224 |
+
|
225 |
+
return get_chat_response, client
|
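Worth noting for deployment: the localhost entries in _get_config_for_model_name are only a fallback, and api_streaming will talk to whatever OpenAI-compatible endpoint the API_MODEL_CONFIG environment variable maps a model id to. A hedged sketch of that override; the model id is taken from the fallback table above, the URL and key are illustrative only:

import json, os

# Hypothetical override consumed by _get_config_for_model_name(); each value is passed
# straight to AsyncOpenAI(**config), so any OpenAI-compatible server works here.
os.environ["API_MODEL_CONFIG"] = json.dumps(
    {"Qwen/Qwen2-Audio-7B-Instruct": {"base_url": "http://localhost:8004/v1", "api_key": "empty"}}
)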
226 |
+
|
227 |
+
|
228 |
+
# Local Hosting Utilities
|
229 |
+
|
230 |
+
|
231 |
+
def diva_streaming(diva_model_str):
|
232 |
+
diva_model = AutoModel.from_pretrained(diva_model_str, trust_remote_code=True, device_map="balanced_low_0")
|
233 |
+
resampler = Audio(sampling_rate=16_000)
|
234 |
+
|
235 |
+
async def diva_audio(audio_input, do_sample=False, temperature=0.001):
|
236 |
+
sr, y = audio_input
|
237 |
+
y = y.astype(np.float32)
|
238 |
+
y /= np.max(np.abs(y))
|
239 |
+
a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
|
240 |
+
stream = diva_model.generate_stream(
|
241 |
+
a["array"],
|
242 |
+
(
|
243 |
+
"You are a helpful assistant The user is talking to you with their voice and you are responding with"
|
244 |
+
" text."
|
245 |
+
),
|
246 |
+
do_sample=do_sample,
|
247 |
+
max_new_tokens=256,
|
248 |
+
)
|
249 |
+
for text in stream:
|
250 |
+
yield text
|
251 |
+
|
252 |
+
return diva_audio, diva_model
|
253 |
+
|
254 |
+
|
255 |
+
def qwen2_streaming(qwen2_model_str):
|
256 |
+
resampler = Audio(sampling_rate=16_000)
|
257 |
+
qwen2_processor = AutoProcessor.from_pretrained(qwen2_model_str)
|
258 |
+
qwen2_model = Qwen2AudioForConditionalGeneration.from_pretrained(qwen2_model_str, device_map="auto")
|
259 |
+
qwen2_model.generation_config = GenerationConfig.from_pretrained(
|
260 |
+
qwen2_model_str,
|
261 |
+
trust_remote_code=True,
|
262 |
+
do_sample=False,
|
263 |
+
top_k=50,
|
264 |
+
top_p=1.0,
|
265 |
+
)
|
266 |
+
|
267 |
+
async def qwen2_audio(audio_input, do_sample=False, temperature=0.001):
|
268 |
+
if audio_input is None:
|
269 |
+
raise StopAsyncIteration("")
|
270 |
+
sr, y = audio_input
|
271 |
+
x = xxhash.xxh32(bytes(y)).hexdigest()
|
272 |
+
y = y.astype(np.float32)
|
273 |
+
y /= np.max(np.abs(y))
|
274 |
+
a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
|
275 |
+
sf.write(f"{x}.wav", a["array"], a["sampling_rate"], format="wav")
|
276 |
+
conversation = [
|
277 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
278 |
+
{
|
279 |
+
"role": "user",
|
280 |
+
"content": [
|
281 |
+
{
|
282 |
+
"type": "audio",
|
283 |
+
"audio_url": f"{x}.wav",
|
284 |
+
},
|
285 |
+
],
|
286 |
+
},
|
287 |
+
]
|
288 |
+
text = qwen2_processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
289 |
+
audios = [librosa.load(f"{x}.wav", sr=qwen2_processor.feature_extractor.sampling_rate)[0]]
|
290 |
+
inputs = qwen2_processor(text=text, audios=audios, return_tensors="pt", padding=True)
|
291 |
+
streamer = TextIteratorStreamer(qwen2_processor)
|
292 |
+
generation_task = asyncio.create_task(qwen2_model.generate(**inputs, streamer=streamer, max_length=256))
|
293 |
+
|
294 |
+
generated_text = ""
|
295 |
+
async for new_text in streamer:
|
296 |
+
generated_text += new_text
|
297 |
+
yield generated_text.split("<|im_start|>assistant\n")[-1].replace("<|im_end|>", "")
|
298 |
+
|
299 |
+
await generation_task
|
300 |
+
os.remove(f"{x}.wav")
|
301 |
+
|
302 |
+
return qwen2_audio, qwen2_model
|
303 |
+
|
304 |
+
|
305 |
+
def typhoon_streaming(typhoon_model_str, device="cuda:0"):
|
306 |
+
resampler = Audio(sampling_rate=16_000)
|
307 |
+
typhoon_model = AutoModel.from_pretrained(typhoon_model_str, torch_dtype=torch.float16, trust_remote_code=True)
|
308 |
+
tokenizer = typhoon_model.llama_tokenizer
|
309 |
+
typhoon_model.to(device)
|
310 |
+
typhoon_model.eval()
|
311 |
+
prompt_pattern = (
|
312 |
+
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<Speech><SpeechHere></Speech>"
|
313 |
+
" {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
314 |
+
)
|
315 |
+
prompt = (
|
316 |
+
"You are a helpful assistant. Respond conversationally to the speech provided in the language it is spoken in."
|
317 |
+
)
|
318 |
+
|
319 |
+
async def typhoon_audio(audio_input, do_sample=False, temperature=0.001):
|
320 |
+
if audio_input is None:
|
321 |
+
raise StopAsyncIteration("")
|
322 |
+
sr, y = audio_input
|
323 |
+
x = xxhash.xxh32(bytes(y)).hexdigest()
|
324 |
+
y = y.astype(np.float32)
|
325 |
+
y /= np.max(np.abs(y))
|
326 |
+
a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
|
327 |
+
streamer = TextIteratorStreamer(tokenizer)
|
328 |
+
generation_task = asyncio.create_task(
|
329 |
+
typhoon_model.generate(
|
330 |
+
audio=a["array"],
|
331 |
+
prompt=prompt,
|
332 |
+
prompt_pattern=prompt_pattern,
|
333 |
+
device=device,
|
334 |
+
do_sample=False,
|
335 |
+
max_length=1200,
|
336 |
+
num_beams=1,
|
337 |
+
streamer=streamer, # supports TextIteratorStreamer
|
338 |
+
)
|
339 |
+
)
|
340 |
+
generated_text = ""
|
341 |
+
async for new_text in streamer:
|
342 |
+
generated_text += new_text
|
343 |
+
yield generated_text.split("<|start_header_id|>assistant<|end_header_id|>\n\n")[-1].replace(
|
344 |
+
"<|eot_id|>", ""
|
345 |
+
)
|
346 |
+
await generation_task
|
347 |
+
|
348 |
+
return typhoon_audio, typhoon_model
|
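One observation on this file as a whole: every audio handler repeats the same preprocessing (xxhash the raw buffer for a temp filename, peak-normalize, resample to 16 kHz through datasets.Audio, write a wav to disk). A sketch of a shared helper these functions could call instead; the helper name is ours, and silent input would still need guarding since np.max(np.abs(y)) can be zero:

def audio_input_to_wav(audio_input, resampler):
    # Mirrors the steps already used above: hash -> normalize -> resample -> temp wav on disk.
    sr, y = audio_input
    name = xxhash.xxh32(bytes(y)).hexdigest()
    y = y.astype(np.float32)
    y /= np.max(np.abs(y))
    a = resampler.decode_example(resampler.encode_example({"array": y, "sampling_rate": sr}))
    sf.write(f"{name}.wav", a["array"], a["sampling_rate"], format="wav")
    return f"{name}.wav"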
talk_arena/styles.css
ADDED
@@ -0,0 +1,25 @@
1 |
+
@media (max-width: 768px) {
|
2 |
+
.lam-response-box {
|
3 |
+
max-height: 230px;
|
4 |
+
}
|
5 |
+
|
6 |
+
.lam-response-box > label {
|
7 |
+
max-height: 100%;
|
8 |
+
}
|
9 |
+
|
10 |
+
.lam-response-box > label > div > textarea{
|
11 |
+
max-height: 100%;
|
12 |
+
height: 100% !important;
|
13 |
+
}
|
14 |
+
}
|
15 |
+
|
16 |
+
@media (min-width: 769px) {
|
17 |
+
.lam-response-box {
|
18 |
+
max-height: 40vh;
|
19 |
+
}
|
20 |
+
|
21 |
+
.lam-response-box > label > div > textarea{
|
22 |
+
max-height: calc(40vh - 50px) !important;
|
23 |
+
}
|
24 |
+
}
|
25 |
+
|
talk_arena/viz/core.py
ADDED
@@ -0,0 +1,324 @@
1 |
+
import json
|
2 |
+
import random
|
3 |
+
import textwrap
|
4 |
+
from collections import defaultdict
|
5 |
+
from datetime import datetime
|
6 |
+
from typing import Dict, List, Tuple
|
7 |
+
from zoneinfo import ZoneInfo
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import plotly.express as px
|
12 |
+
import plotly.io as pio
|
13 |
+
from scipy.optimize import minimize
|
14 |
+
from scipy.special import expit
|
15 |
+
|
16 |
+
# Constants
|
17 |
+
COLORS = [
|
18 |
+
"#1B7FFF",
|
19 |
+
"#F07D1A",
|
20 |
+
"#BA24C7",
|
21 |
+
"#FE42C7",
|
22 |
+
"#0D4B7C",
|
23 |
+
"#0EAC96",
|
24 |
+
"#AA7CFF",
|
25 |
+
"#B50550",
|
26 |
+
"#009EEB",
|
27 |
+
"#220B55",
|
28 |
+
"#7B3301",
|
29 |
+
]
|
30 |
+
NAME_MAPPING = {
|
31 |
+
"gemini_2f": "Gemini 2.0 (Exp)",
|
32 |
+
"diva_3_8b": "DiVA Llama 3 8B",
|
33 |
+
"qwen2": "Qwen 2 Audio",
|
34 |
+
"pipe_l3.0": "Pipelined Llama 3 8B",
|
35 |
+
"gemini_1.5f": "Gemini 1.5 Flash",
|
36 |
+
"gpt4o": "GPT-4o",
|
37 |
+
"gemini_1.5p": "Gemini 1.5 Pro",
|
38 |
+
"typhoon_audio": "Typhoon Audio",
|
39 |
+
}
|
40 |
+
|
41 |
+
def get_aesthetic_timestamp():
|
42 |
+
"""
|
43 |
+
Returns a beautifully formatted timestamp in the format:
|
44 |
+
'Tuesday, December 10th, 2024 at 3:45 PM'
|
45 |
+
"""
|
46 |
+
# Get timezone object for PST
|
47 |
+
pst = ZoneInfo("America/Los_Angeles")
|
48 |
+
|
49 |
+
# Get current time in PST
|
50 |
+
now = datetime.now(pst)
|
51 |
+
|
52 |
+
# Add suffix to day number (1st, 2nd, 3rd, etc.)
|
53 |
+
day = now.day
|
54 |
+
if 4 <= day <= 20 or 24 <= day <= 30:
|
55 |
+
suffix = "th"
|
56 |
+
else:
|
57 |
+
suffix = ["st", "nd", "rd"][day % 10 - 1]
|
58 |
+
return now.strftime(f"%A, %B {day}{suffix}, %Y at %-I:%M %p")
|
59 |
+
|
60 |
+
|
61 |
+
def bootstrap_ci(data, n_bootstrap=10000, ci=95):
|
62 |
+
"""Calculate bootstrap confidence intervals."""
|
63 |
+
bootstrap_samples = []
|
64 |
+
for _ in range(n_bootstrap):
|
65 |
+
bootstrap_samples.append(np.mean(random.choices(data, k=len(data))))
|
66 |
+
lower_bound = np.percentile(bootstrap_samples, (100 - ci) / 2)
|
67 |
+
upper_bound = np.percentile(bootstrap_samples, 100 - (100 - ci) / 2)
|
68 |
+
return lower_bound, upper_bound
|
69 |
+
|
70 |
+
|
71 |
+
def calculate_win_rates(json_data):
|
72 |
+
"""Calculate win rates from JSON data."""
|
73 |
+
data = json.loads(json_data)
|
74 |
+
|
75 |
+
model_wins = defaultdict(int)
|
76 |
+
total_matches = defaultdict(int)
|
77 |
+
total_votes = 0
|
78 |
+
|
79 |
+
for value in data["_default"].values():
|
80 |
+
total_votes += 1
|
81 |
+
if value["outcome"] == 0:
|
82 |
+
model_wins[value["model_a"]] += 1
|
83 |
+
elif value["outcome"] == 1:
|
84 |
+
model_wins[value["model_b"]] += 1
|
85 |
+
elif value["outcome"] == 0.5:
|
86 |
+
model_wins[value["model_a"]] += 0.5
|
87 |
+
model_wins[value["model_b"]] += 0.5
|
88 |
+
total_matches[value["model_a"]] += 1
|
89 |
+
total_matches[value["model_b"]] += 1
|
90 |
+
|
91 |
+
per_model_wins = {}
|
92 |
+
for model, wins in model_wins.items():
|
93 |
+
win_rate = wins / total_matches[model]
|
94 |
+
wins_data = [1] * int(wins) + [0] * int(total_matches[model] - wins)
|
95 |
+
if int(wins) != wins:
|
96 |
+
wins_data += [0.5]
|
97 |
+
lower, upper = bootstrap_ci(wins_data)
|
98 |
+
per_model_wins[model] = {
|
99 |
+
"model": model,
|
100 |
+
"win_rate": win_rate,
|
101 |
+
"95_lower": (win_rate - lower),
|
102 |
+
"95_upper": (upper - win_rate),
|
103 |
+
}
|
104 |
+
df = pd.DataFrame.from_dict(per_model_wins).T
|
105 |
+
|
106 |
+
return df, total_votes
|
107 |
+
|
108 |
+
|
109 |
+
def create_win_rate_plot(wins_df):
|
110 |
+
"""Create win rate plot using Plotly."""
|
111 |
+
wins_df["Source"] = wins_df["Source"].astype(str)
|
112 |
+
wins_df = wins_df.sort_values(by=["Source", "win_rate"], ascending=False)
|
113 |
+
wins_df["model"] = wins_df["model"].apply(lambda x: NAME_MAPPING.get(x, x))
|
114 |
+
|
115 |
+
fig = px.bar(
|
116 |
+
wins_df,
|
117 |
+
x="model",
|
118 |
+
y="win_rate",
|
119 |
+
error_y="95_upper",
|
120 |
+
error_y_minus="95_lower",
|
121 |
+
color="model",
|
122 |
+
color_discrete_sequence=COLORS,
|
123 |
+
animation_group="model",
|
124 |
+
animation_frame="Source",
|
125 |
+
)
|
126 |
+
|
127 |
+
fig.update_traces(
|
128 |
+
hovertemplate="<b>%{x}</b><br>" + "Win Rate: %{y}" + "<extra></extra>",
|
129 |
+
)
|
130 |
+
|
131 |
+
fig.update_layout(
|
132 |
+
autosize=True,
|
133 |
+
showlegend=False,
|
134 |
+
plot_bgcolor="white",
|
135 |
+
title={
|
136 |
+
"text": "Talk Arena Live Win Rates<br>with 95% Confidence Intervals",
|
137 |
+
"y": 0.95,
|
138 |
+
"x": 0.5,
|
139 |
+
"xanchor": "center",
|
140 |
+
"yanchor": "top",
|
141 |
+
},
|
142 |
+
xaxis_title="Model",
|
143 |
+
yaxis_title="Win Rate (%)",
|
144 |
+
bargap=0.2,
|
145 |
+
yaxis=dict(
|
146 |
+
tickformat=".0%", tickmode="auto", range=[0, 1.01], gridcolor="#C9CCD1", griddash="dash", gridwidth=2
|
147 |
+
),
|
148 |
+
legend=dict(
|
149 |
+
orientation="h", # Make legend horizontal
|
150 |
+
yanchor="bottom",
|
151 |
+
y=-0.5, # Position below plot
|
152 |
+
xanchor="center",
|
153 |
+
x=0.5, # Center horizontally
|
154 |
+
bgcolor="rgba(255, 255, 255, 0.8)",
|
155 |
+
bordercolor="#C9CCD1",
|
156 |
+
borderwidth=1,
|
157 |
+
),
|
158 |
+
margin=dict(l=10, r=10, t=0, b=10), # Balanced margins
|
159 |
+
hoverlabel=dict(bgcolor="white", font_size=14, bordercolor="gray"),
|
160 |
+
)
|
161 |
+
|
162 |
+
fig.update_xaxes(showgrid=False)
|
163 |
+
|
164 |
+
return fig
|
165 |
+
|
166 |
+
|
167 |
+
# Bradley-Terry Model Functions
|
168 |
+
def load_live_votes(json_str: str) -> pd.DataFrame:
|
169 |
+
"""Load and preprocess live votes data from JSON string."""
|
170 |
+
data = json.loads(json_str)
|
171 |
+
df = pd.DataFrame.from_dict(data["_default"], orient="index")
|
172 |
+
df["winner"] = df["outcome"].map({1: "model_b", 0: "model_a", 0.5: "tie"})
|
173 |
+
return df
|
174 |
+
|
175 |
+
|
176 |
+
def preprocess_for_bt(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray, List[str], np.ndarray]:
|
177 |
+
"""Preprocess data for Bradley-Terry model fitting."""
|
178 |
+
all_models = pd.concat([df["model_a"], df["model_b"]]).unique()
|
179 |
+
model_to_idx = {model: idx for idx, model in enumerate(all_models)}
|
180 |
+
|
181 |
+
matchups = np.array([[model_to_idx[row.model_a], model_to_idx[row.model_b]] for _, row in df.iterrows()])
|
182 |
+
|
183 |
+
outcomes = np.array(
|
184 |
+
[1.0 if row.winner == "model_a" else (0.5 if row.winner == "tie" else 0.0) for _, row in df.iterrows()]
|
185 |
+
)
|
186 |
+
|
187 |
+
unique_matches = np.column_stack([matchups, outcomes])
|
188 |
+
unique_matches, weights = np.unique(unique_matches, return_counts=True, axis=0)
|
189 |
+
|
190 |
+
return (unique_matches[:, :2].astype(np.int32), unique_matches[:, 2], list(all_models), weights.astype(np.float64))
|
191 |
+
|
192 |
+
|
193 |
+
def bt_loss_and_grad(
|
194 |
+
ratings: np.ndarray, matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, alpha: float = 1.0
|
195 |
+
) -> Tuple[float, np.ndarray]:
|
196 |
+
"""Compute Bradley-Terry loss and gradient."""
|
197 |
+
matchup_ratings = ratings[matchups]
|
198 |
+
logits = alpha * (matchup_ratings[:, 0] - matchup_ratings[:, 1])
|
199 |
+
probs = expit(logits)
|
200 |
+
|
201 |
+
loss = -((np.log(probs) * outcomes + np.log(1.0 - probs) * (1.0 - outcomes)) * weights).sum()
|
202 |
+
|
203 |
+
matchups_grads = -alpha * (outcomes - probs) * weights
|
204 |
+
model_grad = np.zeros_like(ratings)
|
205 |
+
np.add.at(model_grad, matchups[:, [0, 1]], matchups_grads[:, None] * np.array([1.0, -1.0], dtype=np.float64))
|
206 |
+
|
207 |
+
return loss, model_grad
|
208 |
+
|
209 |
+
|
210 |
+
def fit_bt(
|
211 |
+
matchups: np.ndarray, outcomes: np.ndarray, weights: np.ndarray, n_models: int, alpha: float, tol: float = 1e-6
|
212 |
+
) -> np.ndarray:
|
213 |
+
"""Fit Bradley-Terry model using L-BFGS-B optimization."""
|
214 |
+
initial_ratings = np.zeros(n_models, dtype=np.float64)
|
215 |
+
|
216 |
+
result = minimize(
|
217 |
+
fun=bt_loss_and_grad,
|
218 |
+
x0=initial_ratings,
|
219 |
+
args=(matchups, outcomes, weights, alpha),
|
220 |
+
jac=True,
|
221 |
+
method="L-BFGS-B",
|
222 |
+
options={"disp": False, "maxiter": 100, "gtol": tol},
|
223 |
+
)
|
224 |
+
|
225 |
+
return result["x"]
|
226 |
+
|
227 |
+
|
228 |
+
def scale_and_offset(
|
229 |
+
ratings: np.ndarray, models: List[str], scale: float = 400, init_rating: float = 1000
|
230 |
+
) -> np.ndarray:
|
231 |
+
"""Scale ratings to familiar Elo-like scale."""
|
232 |
+
scaled_ratings = (ratings * scale) + init_rating
|
233 |
+
return scaled_ratings
|
234 |
+
|
235 |
+
|
236 |
+
def compute_bootstrap_bt(
|
237 |
+
data: str,
|
238 |
+
num_round: int = 100,
|
239 |
+
base: float = 10.0,
|
240 |
+
scale: float = 400.0,
|
241 |
+
init_rating: float = 1000.0,
|
242 |
+
tol: float = 1e-6,
|
243 |
+
) -> pd.DataFrame:
|
244 |
+
"""Compute bootstrap Bradley-Terry ratings from live votes data."""
|
245 |
+
df = load_live_votes(data)
|
246 |
+
matchups, outcomes, models, weights = preprocess_for_bt(df)
|
247 |
+
|
248 |
+
rng = np.random.default_rng(seed=0)
|
249 |
+
total_matches = len(df)
|
250 |
+
idxs = rng.multinomial(n=total_matches, pvals=weights / weights.sum(), size=num_round)
|
251 |
+
boot_weights = idxs.astype(np.float64) / total_matches
|
252 |
+
|
253 |
+
ratings_list = []
|
254 |
+
for sample_weights in boot_weights:
|
255 |
+
ratings = fit_bt(
|
256 |
+
matchups=matchups,
|
257 |
+
outcomes=outcomes,
|
258 |
+
weights=sample_weights,
|
259 |
+
n_models=len(models),
|
260 |
+
alpha=np.log(base),
|
261 |
+
tol=tol,
|
262 |
+
)
|
263 |
+
scaled_ratings = scale_and_offset(ratings=ratings, models=models, scale=scale, init_rating=init_rating)
|
264 |
+
ratings_list.append(scaled_ratings)
|
265 |
+
|
266 |
+
df_ratings = pd.DataFrame(ratings_list, columns=models)
|
267 |
+
return df_ratings[df_ratings.median().sort_values(ascending=False).index]
|
268 |
+
|
269 |
+
|
270 |
+
def create_bt_plot(bootstrap_ratings):
|
271 |
+
"""Create Bradley-Terry ratings plot using Plotly."""
|
272 |
+
melted_bootstrap = bootstrap_ratings.melt(id_vars=["Source", "level_1"], var_name="Model", value_name="BT")
|
273 |
+
melted_bootstrap = melted_bootstrap.dropna()
|
274 |
+
melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "Model", "BT"], ascending=False)
|
275 |
+
# Pretty Names
|
276 |
+
melted_bootstrap["Model"] = melted_bootstrap["Model"].apply(lambda x: NAME_MAPPING.get(x, x))
|
277 |
+
# Compression for Client Side
|
278 |
+
melted_bootstrap["BT"] = melted_bootstrap["BT"].apply(lambda x: int(x))
|
279 |
+
min_samp = melted_bootstrap[melted_bootstrap["BT"] > 0]["BT"].min()
|
280 |
+
max_samp = melted_bootstrap["BT"].max()
|
281 |
+
idx_keep = list(range(0, len(melted_bootstrap), 10))
|
282 |
+
melted_bootstrap = melted_bootstrap.iloc[idx_keep]
|
283 |
+
melted_bootstrap = melted_bootstrap.sort_values(by=["Source", "BT"], ascending=False)
|
284 |
+
fig = px.violin(
|
285 |
+
melted_bootstrap,
|
286 |
+
x="Model",
|
287 |
+
y="BT",
|
288 |
+
color="Model",
|
289 |
+
animation_group="Model",
|
290 |
+
animation_frame="Source",
|
291 |
+
color_discrete_sequence=COLORS,
|
292 |
+
)
|
293 |
+
|
294 |
+
fig.update_layout(
|
295 |
+
autosize=True,
|
296 |
+
showlegend=False,
|
297 |
+
plot_bgcolor="white",
|
298 |
+
title={
|
299 |
+
"text": "Talk Arena Live Bradley-Terry Ratings<br>with Bootstrapped Variance",
|
300 |
+
"y": 0.9,
|
301 |
+
"x": 0.5,
|
302 |
+
"xanchor": "center",
|
303 |
+
"yanchor": "top",
|
304 |
+
},
|
305 |
+
xaxis_title="Model",
|
306 |
+
yaxis_title="Rating",
|
307 |
+
yaxis=dict(gridcolor="#C9CCD1", range=[min_samp - 10, max_samp + 10], griddash="dash"),
|
308 |
+
legend=dict(
|
309 |
+
orientation="h", # Make legend horizontal
|
310 |
+
yanchor="bottom",
|
311 |
+
y=-0.5, # Position below plot
|
312 |
+
xanchor="center",
|
313 |
+
x=0.5, # Center horizontally
|
314 |
+
bgcolor="rgba(255, 255, 255, 0.8)",
|
315 |
+
bordercolor="#C9CCD1",
|
316 |
+
borderwidth=1,
|
317 |
+
),
|
318 |
+
margin=dict(l=10, r=10, t=0, b=10), # Balanced margins
|
319 |
+
)
|
320 |
+
|
321 |
+
fig.update_xaxes(showgrid=False)
|
322 |
+
fig.update_yaxes(showgrid=True, gridwidth=2)
|
323 |
+
|
324 |
+
return fig
|
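For readers of this module: calculate_win_rates() and load_live_votes() both consume the vote JSON (a top-level "_default" table keyed by vote id), where outcome 0 means model_a won, 1 means model_b won, and 0.5 is a tie; bt_loss_and_grad() models P(a beats b) as expit(alpha * (r_a - r_b)), and scale_and_offset() then maps ratings onto an Elo-like scale (400 * r + 1000). A small illustration with made-up votes and ratings, assuming this module is importable:

import json
import numpy as np
from scipy.special import expit
from talk_arena.viz.core import calculate_win_rates

votes = {"_default": {"1": {"model_a": "gpt4o", "model_b": "diva_3_8b", "outcome": 0}}}  # fabricated example vote
wr_df, n_votes = calculate_win_rates(json.dumps(votes))

# Win probability implied by two illustrative BT ratings under base 10 (alpha = ln(10)).
p_a_beats_b = expit(np.log(10.0) * (0.30 - (-0.10)))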
talk_arena/viz/server.py
ADDED
@@ -0,0 +1,323 @@
1 |
+
import hashlib
|
2 |
+
import json
|
3 |
+
import textwrap
|
4 |
+
import time
|
5 |
+
from datetime import datetime
|
6 |
+
from typing import Optional
|
7 |
+
from zoneinfo import ZoneInfo
|
8 |
+
|
9 |
+
import plotly.io as pio
|
10 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
11 |
+
from fastapi import FastAPI, HTTPException, Response
|
12 |
+
from fastapi.middleware.cors import CORSMiddleware
|
13 |
+
|
14 |
+
from talk_arena.viz.core import *
|
15 |
+
|
16 |
+
app = FastAPI(title="Talk Arena API", description="API for Talk Arena leaderboard and statistics", version="0.0.1")
|
17 |
+
|
18 |
+
# Add CORS middleware
|
19 |
+
app.add_middleware(
|
20 |
+
CORSMiddleware,
|
21 |
+
allow_origins=["*"], # In production, replace with specific origins
|
22 |
+
allow_credentials=True,
|
23 |
+
allow_methods=["*"],
|
24 |
+
allow_headers=["*"],
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
# Global variables to store the plots and update time
|
29 |
+
class GlobalState:
|
30 |
+
WR_PLOT = None
|
31 |
+
BT_PLOT = None
|
32 |
+
UPDATE_TIME = None
|
33 |
+
LAST_PROCESSED = None
|
34 |
+
MIN_UPDATE_INTERVAL = 60 # Minimum seconds between updates
|
35 |
+
|
36 |
+
|
37 |
+
state = GlobalState()
|
38 |
+
|
39 |
+
|
40 |
+
def process_and_visualize(force: bool = False):
|
41 |
+
"""Process data and create visualizations"""
|
42 |
+
global state
|
43 |
+
current_time = datetime.now(ZoneInfo("America/Los_Angeles"))
|
44 |
+
|
45 |
+
# Check if enough time has passed since last update
|
46 |
+
if not force and state.LAST_PROCESSED:
|
47 |
+
time_diff = (current_time - state.LAST_PROCESSED).total_seconds()
|
48 |
+
if time_diff < state.MIN_UPDATE_INTERVAL:
|
49 |
+
logger.info(f"Skipping update - only {time_diff:.1f} seconds since last update")
|
50 |
+
return
|
51 |
+
|
52 |
+
state.LAST_PROCESSED = current_time
|
53 |
+
if state.WR_PLOT is not None and state.BT_PLOT is not None and not force:
|
54 |
+
return
|
55 |
+
|
56 |
+
try:
|
57 |
+
# Read JSON data
|
58 |
+
pub_json_data = open("/home/wheld3/talk-arena/live_votes.json", "r").read()
|
59 |
+
prolific_json_data = open("/home/wheld3/talk-arena/prolific_votes.json", "r").read()
|
60 |
+
merged_json_data = json.dumps(
|
61 |
+
{"_default": {**json.loads(pub_json_data)["_default"], **json.loads(prolific_json_data)["_default"]}}
|
62 |
+
)
|
63 |
+
|
64 |
+
# Calculate win rates and create plots
|
65 |
+
pub_win_rates, pub_votes = calculate_win_rates(pub_json_data)
|
66 |
+
pro_win_rates, pro_votes = calculate_win_rates(prolific_json_data)
|
67 |
+
total_win_rates, total_votes = calculate_win_rates(merged_json_data)
|
68 |
+
|
69 |
+
# Process win rates
|
70 |
+
all_models = total_win_rates["model"].unique()
|
71 |
+
pro_models = pro_win_rates["model"].unique()
|
72 |
+
for model in all_models:
|
73 |
+
if model not in pro_models:
|
74 |
+
new_index = len(pro_win_rates)
|
75 |
+
pro_win_rates.loc[new_index] = [model, -0.1, -0.1, -0.2]
|
76 |
+
|
77 |
+
win_rates = (
|
78 |
+
pd.concat([pub_win_rates, pro_win_rates, total_win_rates], keys=["Public", "Prolific", "Total"])
|
79 |
+
.reset_index()
|
80 |
+
.rename(columns={"level_0": "Source"})
|
81 |
+
)
|
82 |
+
|
83 |
+
state.WR_PLOT = create_win_rate_plot(win_rates)
|
84 |
+
|
85 |
+
# Calculate Bradley-Terry ratings
|
86 |
+
pub_bootstrap_ratings = compute_bootstrap_bt(pub_json_data, num_round=10000)
|
87 |
+
pro_bootstrap_ratings = compute_bootstrap_bt(prolific_json_data, num_round=10000)
|
88 |
+
total_bootstrap_ratings = compute_bootstrap_bt(merged_json_data, num_round=10000)
|
89 |
+
|
90 |
+
for model in all_models:
|
91 |
+
if model not in pro_models:
|
92 |
+
pro_bootstrap_ratings[model] = pro_bootstrap_ratings["diva_3_8b"] * -1
|
93 |
+
|
94 |
+
bootstrap_ratings = (
|
95 |
+
pd.concat(
|
96 |
+
[pub_bootstrap_ratings, pro_bootstrap_ratings, total_bootstrap_ratings],
|
97 |
+
keys=["Public", "Prolific", "Total"],
|
98 |
+
)
|
99 |
+
.reset_index()
|
100 |
+
.rename(columns={"level_0": "Source"})
|
101 |
+
)
|
102 |
+
|
103 |
+
state.BT_PLOT = create_bt_plot(bootstrap_ratings)
|
104 |
+
|
105 |
+
# Update timestamp and vote counts
|
106 |
+
state.UPDATE_TIME = {
|
107 |
+
"timestamp": get_aesthetic_timestamp(),
|
108 |
+
"total_votes": total_votes,
|
109 |
+
"public_votes": pub_votes,
|
110 |
+
"prolific_votes": pro_votes,
|
111 |
+
}
|
112 |
+
|
113 |
+
except Exception as e:
|
114 |
+
raise HTTPException(status_code=500, detail=f"Error processing data: {str(e)}")
|
115 |
+
|
116 |
+
|
117 |
+
# Set up logging
|
118 |
+
import logging
|
119 |
+
|
120 |
+
|
121 |
+
logging.basicConfig(level=logging.INFO)
|
122 |
+
logger = logging.getLogger(__name__)
|
123 |
+
|
124 |
+
# Global scheduler instance
|
125 |
+
scheduler = None
|
126 |
+
|
127 |
+
|
128 |
+
def update_job():
|
129 |
+
"""Wrapper for the update job with error handling and logging"""
|
130 |
+
try:
|
131 |
+
logger.info("Starting scheduled update...")
|
132 |
+
process_and_visualize(force=True)
|
133 |
+
logger.info("Scheduled update completed successfully")
|
134 |
+
except Exception as e:
|
135 |
+
logger.error(f"Error in scheduled update: {str(e)}", exc_info=True)
|
136 |
+
|
137 |
+
|
138 |
+
@app.on_event("startup")
|
139 |
+
async def startup_event():
|
140 |
+
"""Initialize data and start scheduler"""
|
141 |
+
global scheduler
|
142 |
+
|
143 |
+
try:
|
144 |
+
logger.info("Starting initial data processing...")
|
145 |
+
process_and_visualize(force=True)
|
146 |
+
logger.info("Initial data processing completed")
|
147 |
+
|
148 |
+
# Clear any existing schedulers
|
149 |
+
if scheduler:
|
150 |
+
scheduler.shutdown(wait=False)
|
151 |
+
|
152 |
+
# Initialize and start the scheduler
|
153 |
+
scheduler = BackgroundScheduler(
|
154 |
+
timezone=ZoneInfo("America/Los_Angeles"), job_defaults={"coalesce": True, "max_instances": 1}
|
155 |
+
)
|
156 |
+
|
157 |
+
# Add the job with proper error handling
|
158 |
+
scheduler.add_job(
|
159 |
+
func=update_job, # Use the wrapper function
|
160 |
+
trigger="interval",
|
161 |
+
seconds=300,
|
162 |
+
id="update_visualizations",
|
163 |
+
name="Update Visualizations",
|
164 |
+
misfire_grace_time=60,
|
165 |
+
)
|
166 |
+
|
167 |
+
scheduler.start()
|
168 |
+
logger.info("Scheduler started successfully")
|
169 |
+
|
170 |
+
# Verify the job was added
|
171 |
+
jobs = scheduler.get_jobs()
|
172 |
+
logger.info(f"Current scheduled jobs: {[job.name for job in jobs]}")
|
173 |
+
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"Error during startup: {str(e)}", exc_info=True)
|
176 |
+
raise
|
177 |
+
|
178 |
+
|
179 |
+
@app.on_event("shutdown")
|
180 |
+
async def shutdown_event():
|
181 |
+
"""Properly shutdown the scheduler when the app stops"""
|
182 |
+
global scheduler
|
183 |
+
try:
|
184 |
+
if scheduler:
|
185 |
+
logger.info("Shutting down scheduler...")
|
186 |
+
scheduler.shutdown(wait=False)
|
187 |
+
logger.info("Scheduler shutdown complete")
|
188 |
+
except Exception as e:
|
189 |
+
logger.error(f"Error during scheduler shutdown: {str(e)}", exc_info=True)
|
190 |
+
|
191 |
+
|
192 |
+
# Add an endpoint to manually trigger an update
|
193 |
+
@app.post("/api/trigger-update")
|
194 |
+
async def trigger_update():
|
195 |
+
"""Manually trigger a data update"""
|
196 |
+
try:
|
197 |
+
logger.info("Manual update triggered")
|
198 |
+
process_and_visualize(force=True)
|
199 |
+
logger.info("Manual update completed")
|
200 |
+
return {"status": "success", "message": "Update completed"}
|
201 |
+
except Exception as e:
|
202 |
+
logger.error(f"Error in manual update: {str(e)}", exc_info=True)
|
203 |
+
raise HTTPException(status_code=500, detail=str(e))
|
204 |
+
|
205 |
+
|
206 |
+
def generate_etag(data: dict) -> str:
|
207 |
+
"""Generate an ETag for the given data"""
|
208 |
+
# Convert data to a consistent string representation and hash it
|
209 |
+
data_str = json.dumps(data, sort_keys=True)
|
210 |
+
return hashlib.md5(data_str.encode()).hexdigest()
|
211 |
+
|
212 |
+
|
213 |
+
@app.get("/api/win-rate-plot")
|
214 |
+
async def get_wr_plot(response: Response):
|
215 |
+
"""Get the win rate plot data"""
|
216 |
+
if state.WR_PLOT is None:
|
217 |
+
raise HTTPException(status_code=503, detail="Plot data not yet available")
|
218 |
+
|
219 |
+
plot_json = json.loads(pio.to_json(state.WR_PLOT))
|
220 |
+
|
221 |
+
# Customize animation settings
|
222 |
+
for step in plot_json["layout"]["sliders"][0]["steps"]:
|
223 |
+
step["args"][1]["frame"]["duration"] = 500
|
224 |
+
step["args"][1]["transition"]["duration"] = 500
|
225 |
+
|
226 |
+
plot_json["layout"]["updatemenus"] = []
|
227 |
+
plot_json["layout"]["sliders"][0]["len"] = 0.8
|
228 |
+
plot_json["layout"]["sliders"][0]["pad"] = {}
|
229 |
+
|
230 |
+
# Generate ETag
|
231 |
+
etag = generate_etag(plot_json)
|
232 |
+
response.headers["ETag"] = etag
|
233 |
+
|
234 |
+
# Set cache control headers - cache for 4 minutes since we update every 5
|
235 |
+
response.headers["Cache-Control"] = "public, max-age=240"
|
236 |
+
|
237 |
+
return plot_json
|
238 |
+
|
239 |
+
|
240 |
+
@app.get("/api/bt-plot")
|
241 |
+
async def get_bt_plot(response: Response):
|
242 |
+
"""Get the Bradley-Terry plot data"""
|
243 |
+
if state.BT_PLOT is None:
|
244 |
+
raise HTTPException(status_code=503, detail="Plot data not yet available")
|
245 |
+
|
246 |
+
plot_json = json.loads(pio.to_json(state.BT_PLOT))
|
247 |
+
|
248 |
+
# Customize animation settings
|
249 |
+
for step in plot_json["layout"]["sliders"][0]["steps"]:
|
250 |
+
step["args"][1]["frame"]["duration"] = 500
|
251 |
+
step["args"][1]["transition"]["duration"] = 500
|
252 |
+
|
253 |
+
plot_json["layout"]["updatemenus"] = []
|
254 |
+
plot_json["layout"]["sliders"][0]["len"] = 0.8
|
255 |
+
plot_json["layout"]["sliders"][0]["pad"] = {}
|
256 |
+
|
257 |
+
# Generate ETag
|
258 |
+
etag = generate_etag(plot_json)
|
259 |
+
response.headers["ETag"] = etag
|
260 |
+
|
261 |
+
# Set cache control headers - cache for 4 minutes since we update every 5
|
262 |
+
response.headers["Cache-Control"] = "public, max-age=240"
|
263 |
+
|
264 |
+
return plot_json
|
265 |
+
|
266 |
+
|
267 |
+
@app.get("/api/update-time")
|
268 |
+
async def get_update_time(response: Response):
|
269 |
+
"""Get the last update time and vote counts"""
|
270 |
+
if state.UPDATE_TIME is None:
|
271 |
+
raise HTTPException(status_code=503, detail="Update time not yet available")
|
272 |
+
|
273 |
+
# Generate ETag
|
274 |
+
etag = generate_etag(state.UPDATE_TIME)
|
275 |
+
response.headers["ETag"] = etag
|
276 |
+
|
277 |
+
# Set cache control headers - cache for 4 minutes
|
278 |
+
response.headers["Cache-Control"] = "public, max-age=240"
|
279 |
+
|
280 |
+
return state.UPDATE_TIME
|
281 |
+
|
282 |
+
|
283 |
+
@app.get("/api/health")
|
284 |
+
async def health_check(response: Response):
|
285 |
+
"""Enhanced health check endpoint with scheduler status"""
|
286 |
+
global scheduler
|
287 |
+
|
288 |
+
scheduler_status = "not_running"
|
289 |
+
next_run = None
|
290 |
+
last_run = state.UPDATE_TIME["timestamp"] if state.UPDATE_TIME else None
|
291 |
+
|
292 |
+
if scheduler:
|
293 |
+
try:
|
294 |
+
jobs = scheduler.get_jobs()
|
295 |
+
if jobs:
|
296 |
+
scheduler_status = "running"
|
297 |
+
next_run = jobs[0].next_run_time.strftime("%Y-%m-%d %H:%M:%S %Z")
|
298 |
+
except Exception as e:
|
299 |
+
logger.error(f"Error checking scheduler status: {str(e)}")
|
300 |
+
scheduler_status = f"error: {str(e)}"
|
301 |
+
|
302 |
+
health_data = {
|
303 |
+
"status": "healthy",
|
304 |
+
"scheduler_status": scheduler_status,
|
305 |
+
"last_update": last_run,
|
306 |
+
"next_scheduled_update": next_run,
|
307 |
+
"current_time": datetime.now(ZoneInfo("America/Los_Angeles")).strftime("%Y-%m-%d %H:%M:%S %Z"),
|
308 |
+
}
|
309 |
+
|
310 |
+
# Generate ETag
|
311 |
+
etag = generate_etag(health_data)
|
312 |
+
response.headers["ETag"] = etag
|
313 |
+
|
314 |
+
# Set cache control headers - short cache time for health check
|
315 |
+
response.headers["Cache-Control"] = "public, max-age=30"
|
316 |
+
|
317 |
+
return health_data
|
318 |
+
|
319 |
+
|
320 |
+
if __name__ == "__main__":
|
321 |
+
import uvicorn
|
322 |
+
|
323 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
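A closing note on this API: the plot endpoints attach ETag and Cache-Control (max-age=240) headers, but they never compare the incoming If-None-Match value, so conditional revalidation only pays off if a proxy or CDN in front of uvicorn performs that comparison. A hedged usage sketch against the uvicorn defaults at the bottom of this file; the client code is illustrative, not part of the commit:

import requests

base = "http://localhost:8000"
print(requests.get(f"{base}/api/health").json()["scheduler_status"])
bt_fig = requests.get(f"{base}/api/bt-plot").json()      # Plotly figure dict for the Bradley-Terry plot
requests.post(f"{base}/api/trigger-update")              # force a refresh outside the 300-second schedule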