Spaces:
Running
Zg/add head to head results (#19)
Browse files* Add OpenAI python SDK to dependencies
* Fix Anthropic clean API Error message.
* Update constants and custom types associated with TTS providers to include OpenAI
* Add OpenAI integration
* Update logic for selecting providers, add OpenAI tts to UI
* Fix typo in openai_api.py
* Update docstrings in openai_api.py
* Update leaderboard results query to include OpenAI results
* Add citation
* Adjust padding in UI components
* Adjust padding in UI components in citation
* Add transitive dependency override for sounddevice in pyproject.toml
* remove sounddevice
* Add warning toast for custom text inputs
* Improve leaderboard results query to account for zero records, and update to only include relevant comparison types for each provider.
* Fix database package imports and add head-to-head comparison queries
* Add utils for fetching head-to-head comparison data
* Updates UI to include head-to-head comparison tables, updates leaderboard data fetching in UI to include comparison data
---------
Co-authored-by: twitchard <[email protected]>
- src/constants.py +1 -9
- src/database/__init__.py +6 -1
- src/database/crud.py +164 -7
- src/frontend.py +58 -16
- src/scripts/init_db.py +1 -1
- src/scripts/test_db.py +1 -1
- src/utils.py +146 -30
@@ -10,7 +10,6 @@ from typing import Dict, List
|
|
10 |
# Third-Party Library Imports
|
11 |
from src.custom_types import (
|
12 |
ComparisonType,
|
13 |
-
LeaderboardEntry,
|
14 |
OptionKey,
|
15 |
OptionLabel,
|
16 |
TTSProviderName,
|
@@ -26,7 +25,7 @@ HUME_AI: TTSProviderName = "Hume AI"
|
|
26 |
ELEVENLABS: TTSProviderName = "ElevenLabs"
|
27 |
OPENAI: TTSProviderName = "OpenAI"
|
28 |
|
29 |
-
TTS_PROVIDERS: List[TTSProviderName] = ["Hume AI", "
|
30 |
TTS_PROVIDER_LINKS = {
|
31 |
"Hume AI": {
|
32 |
"provider_link": "https://hume.ai/",
|
@@ -169,10 +168,3 @@ META_TAGS: List[Dict[str, str]] = [
|
|
169 |
'content': '/static/arena-opengraph-logo.png'
|
170 |
}
|
171 |
]
|
172 |
-
|
173 |
-
# Reflects and empty leaderboard state
|
174 |
-
DEFAULT_LEADERBOARD: List[LeaderboardEntry] = [
|
175 |
-
LeaderboardEntry("1", "", "", "0%", "0"),
|
176 |
-
LeaderboardEntry("2", "", "", "0%", "0"),
|
177 |
-
LeaderboardEntry("3", "", "", "0%", "0"),
|
178 |
-
]
|
|
|
10 |
# Third-Party Library Imports
|
11 |
from src.custom_types import (
|
12 |
ComparisonType,
|
|
|
13 |
OptionKey,
|
14 |
OptionLabel,
|
15 |
TTSProviderName,
|
|
|
25 |
ELEVENLABS: TTSProviderName = "ElevenLabs"
|
26 |
OPENAI: TTSProviderName = "OpenAI"
|
27 |
|
28 |
+
TTS_PROVIDERS: List[TTSProviderName] = ["Hume AI", "OpenAI", "ElevenLabs"]
|
29 |
TTS_PROVIDER_LINKS = {
|
30 |
"Hume AI": {
|
31 |
"provider_link": "https://hume.ai/",
|
|
|
168 |
'content': '/static/arena-opengraph-logo.png'
|
169 |
}
|
170 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,10 +1,15 @@
|
|
1 |
-
from .crud import create_vote
|
2 |
from .database import AsyncDBSessionMaker, Base, engine, init_db
|
|
|
3 |
|
4 |
__all__ = [
|
5 |
"AsyncDBSessionMaker",
|
6 |
"Base",
|
|
|
7 |
"create_vote",
|
8 |
"engine",
|
|
|
|
|
|
|
9 |
"init_db",
|
10 |
]
|
|
|
1 |
+
from .crud import create_vote, get_head_to_head_battle_stats, get_head_to_head_win_rate_stats, get_leaderboard_stats
|
2 |
from .database import AsyncDBSessionMaker, Base, engine, init_db
|
3 |
+
from .models import VoteResult
|
4 |
|
5 |
__all__ = [
|
6 |
"AsyncDBSessionMaker",
|
7 |
"Base",
|
8 |
+
"VoteResult",
|
9 |
"create_vote",
|
10 |
"engine",
|
11 |
+
"get_head_to_head_battle_stats",
|
12 |
+
"get_head_to_head_win_rate_stats",
|
13 |
+
"get_leaderboard_stats",
|
14 |
"init_db",
|
15 |
]
|
@@ -5,6 +5,9 @@ This module defines the operations for the Expressive TTS Arena project's databa
|
|
5 |
Since vote records are never updated or deleted, only functions to create and read votes are provided.
|
6 |
"""
|
7 |
|
|
|
|
|
|
|
8 |
# Third-Party Library Imports
|
9 |
from sqlalchemy import text
|
10 |
from sqlalchemy.exc import SQLAlchemyError
|
@@ -12,7 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|
12 |
|
13 |
# Local Application Imports
|
14 |
from src.config import logger
|
15 |
-
from src.constants import DEFAULT_LEADERBOARD
|
16 |
from src.custom_types import LeaderboardEntry, LeaderboardTableEntries, VotingResults
|
17 |
from src.database.models import VoteResult
|
18 |
|
@@ -83,6 +85,12 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
|
|
83 |
LeaderboardTableEntries: A list of LeaderboardEntry objects containing rank,
|
84 |
provider name, model name, win rate, and total votes.
|
85 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
try:
|
87 |
query = text(
|
88 |
"""
|
@@ -137,6 +145,10 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
|
|
137 |
result = await db.execute(query)
|
138 |
rows = result.fetchall()
|
139 |
|
|
|
|
|
|
|
|
|
140 |
# Format the data for the leaderboard
|
141 |
leaderboard_data = []
|
142 |
for i, row in enumerate(rows, 1):
|
@@ -150,16 +162,161 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
|
|
150 |
)
|
151 |
leaderboard_data.append(leaderboard_entry)
|
152 |
|
153 |
-
# If no data was found, return default entries
|
154 |
-
if not leaderboard_data:
|
155 |
-
return DEFAULT_LEADERBOARD
|
156 |
-
|
157 |
return leaderboard_data
|
158 |
|
159 |
except SQLAlchemyError as e:
|
160 |
logger.error(f"Database error while fetching leaderboard stats: {e}")
|
161 |
-
return
|
162 |
except Exception as e:
|
163 |
logger.error(f"Unexpected error while fetching leaderboard stats: {e}")
|
164 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
Since vote records are never updated or deleted, only functions to create and read votes are provided.
|
6 |
"""
|
7 |
|
8 |
+
# Standard Library Imports
|
9 |
+
from typing import List
|
10 |
+
|
11 |
# Third-Party Library Imports
|
12 |
from sqlalchemy import text
|
13 |
from sqlalchemy.exc import SQLAlchemyError
|
|
|
15 |
|
16 |
# Local Application Imports
|
17 |
from src.config import logger
|
|
|
18 |
from src.custom_types import LeaderboardEntry, LeaderboardTableEntries, VotingResults
|
19 |
from src.database.models import VoteResult
|
20 |
|
|
|
85 |
LeaderboardTableEntries: A list of LeaderboardEntry objects containing rank,
|
86 |
provider name, model name, win rate, and total votes.
|
87 |
"""
|
88 |
+
default_leaderboard = [
|
89 |
+
LeaderboardEntry("1", "", "", "0%", "0"),
|
90 |
+
LeaderboardEntry("2", "", "", "0%", "0"),
|
91 |
+
LeaderboardEntry("3", "", "", "0%", "0"),
|
92 |
+
]
|
93 |
+
|
94 |
try:
|
95 |
query = text(
|
96 |
"""
|
|
|
145 |
result = await db.execute(query)
|
146 |
rows = result.fetchall()
|
147 |
|
148 |
+
# If no rows, return default
|
149 |
+
if not rows:
|
150 |
+
return default_leaderboard
|
151 |
+
|
152 |
# Format the data for the leaderboard
|
153 |
leaderboard_data = []
|
154 |
for i, row in enumerate(rows, 1):
|
|
|
162 |
)
|
163 |
leaderboard_data.append(leaderboard_entry)
|
164 |
|
|
|
|
|
|
|
|
|
165 |
return leaderboard_data
|
166 |
|
167 |
except SQLAlchemyError as e:
|
168 |
logger.error(f"Database error while fetching leaderboard stats: {e}")
|
169 |
+
return default_leaderboard
|
170 |
except Exception as e:
|
171 |
logger.error(f"Unexpected error while fetching leaderboard stats: {e}")
|
172 |
+
return default_leaderboard
|
173 |
+
|
174 |
+
|
175 |
+
async def get_head_to_head_battle_stats(db: AsyncSession) -> List[List[str]]:
|
176 |
+
"""
|
177 |
+
Fetches the total number of voting results for each comparison type (excluding "Hume AI - Hume AI").
|
178 |
+
|
179 |
+
Args:
|
180 |
+
db (AsyncSession): The SQLAlchemy async database session.
|
181 |
+
|
182 |
+
Returns:
|
183 |
+
List[List[str]]: A list of lists, where each inner list contains the comparison type and the count.
|
184 |
+
"""
|
185 |
+
default_counts = [
|
186 |
+
["Hume AI - OpenAI", "0"],
|
187 |
+
["Hume AI - ElevenLabs", "0"],
|
188 |
+
["OpenAI - ElevenLabs", "0"],
|
189 |
+
]
|
190 |
|
191 |
+
try:
|
192 |
+
query = text(
|
193 |
+
"""
|
194 |
+
SELECT
|
195 |
+
comparison_type,
|
196 |
+
COUNT(*) as total
|
197 |
+
FROM vote_results
|
198 |
+
WHERE comparison_type != 'Hume AI - Hume AI'
|
199 |
+
GROUP BY comparison_type
|
200 |
+
ORDER BY comparison_type;
|
201 |
+
"""
|
202 |
+
)
|
203 |
+
|
204 |
+
result = await db.execute(query)
|
205 |
+
rows = result.fetchall()
|
206 |
+
|
207 |
+
# If no rows, return default
|
208 |
+
if not rows:
|
209 |
+
return default_counts
|
210 |
+
|
211 |
+
# Format the results
|
212 |
+
formatted_results = []
|
213 |
+
for row in rows:
|
214 |
+
comparison_type, count = row
|
215 |
+
formatted_results.append([comparison_type, str(count)])
|
216 |
+
|
217 |
+
# Make sure all expected comparison types are included
|
218 |
+
expected_types = {"Hume AI - OpenAI", "Hume AI - ElevenLabs", "OpenAI - ElevenLabs"}
|
219 |
+
found_types = {row[0] for row in formatted_results}
|
220 |
+
|
221 |
+
# Add missing types with zero counts
|
222 |
+
for type_name in expected_types - found_types:
|
223 |
+
formatted_results.append([type_name, "0"])
|
224 |
+
|
225 |
+
# Sort the results by comparison type
|
226 |
+
formatted_results.sort(key=lambda x: x[0])
|
227 |
+
|
228 |
+
return formatted_results
|
229 |
+
|
230 |
+
except SQLAlchemyError as e:
|
231 |
+
logger.error(f"Database error while fetching comparison counts: {e}")
|
232 |
+
return default_counts
|
233 |
+
except Exception as e:
|
234 |
+
logger.error(f"Unexpected error while fetching comparison counts: {e}")
|
235 |
+
return default_counts
|
236 |
+
|
237 |
+
|
238 |
+
async def get_head_to_head_win_rate_stats(db: AsyncSession) -> List[List[str]]:
|
239 |
+
"""
|
240 |
+
Calculates the win rate for each provider against the other in head-to-head comparisons.
|
241 |
+
|
242 |
+
Args:
|
243 |
+
db (AsyncSession): The SQLAlchemy async database session.
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
List[List[str]]: A list of lists, where each inner list contains:
|
247 |
+
- The comparison type
|
248 |
+
- The win rate of the first provider (the one named first in the comparison type)
|
249 |
+
- The win rate of the second provider (the one named second in the comparison type)
|
250 |
+
"""
|
251 |
+
default_win_rates = [
|
252 |
+
["Hume AI - OpenAI", "0%", "0%"],
|
253 |
+
["Hume AI - ElevenLabs", "0%", "0%"],
|
254 |
+
["OpenAI - ElevenLabs", "0%", "0%"],
|
255 |
+
]
|
256 |
+
|
257 |
+
try:
|
258 |
+
query = text(
|
259 |
+
"""
|
260 |
+
SELECT
|
261 |
+
comparison_type,
|
262 |
+
CASE WHEN COUNT(*) > 0
|
263 |
+
THEN ROUND(SUM(CASE
|
264 |
+
WHEN comparison_type = 'Hume AI - OpenAI' AND winning_provider = 'Hume AI' THEN 1
|
265 |
+
WHEN comparison_type = 'Hume AI - ElevenLabs' AND winning_provider = 'Hume AI' THEN 1
|
266 |
+
WHEN comparison_type = 'OpenAI - ElevenLabs' AND winning_provider = 'OpenAI' THEN 1
|
267 |
+
ELSE 0
|
268 |
+
END) * 100.0 / COUNT(*), 2)
|
269 |
+
ELSE 0
|
270 |
+
END as first_provider_win_rate,
|
271 |
+
CASE WHEN COUNT(*) > 0
|
272 |
+
THEN ROUND(SUM(CASE
|
273 |
+
WHEN comparison_type = 'Hume AI - OpenAI' AND winning_provider = 'OpenAI' THEN 1
|
274 |
+
WHEN comparison_type = 'Hume AI - ElevenLabs' AND winning_provider = 'ElevenLabs' THEN 1
|
275 |
+
WHEN comparison_type = 'OpenAI - ElevenLabs' AND winning_provider = 'ElevenLabs' THEN 1
|
276 |
+
ELSE 0
|
277 |
+
END) * 100.0 / COUNT(*), 2)
|
278 |
+
ELSE 0
|
279 |
+
END as second_provider_win_rate
|
280 |
+
FROM vote_results
|
281 |
+
WHERE comparison_type != 'Hume AI - Hume AI'
|
282 |
+
GROUP BY comparison_type
|
283 |
+
ORDER BY comparison_type;
|
284 |
+
"""
|
285 |
+
)
|
286 |
+
|
287 |
+
result = await db.execute(query)
|
288 |
+
rows = result.fetchall()
|
289 |
+
|
290 |
+
# If no rows, return default
|
291 |
+
if not rows:
|
292 |
+
return default_win_rates
|
293 |
+
|
294 |
+
# Format the results
|
295 |
+
formatted_results = []
|
296 |
+
for row in rows:
|
297 |
+
comparison_type, first_provider_win_rate, second_provider_win_rate = row
|
298 |
+
formatted_results.append([
|
299 |
+
comparison_type,
|
300 |
+
f"{first_provider_win_rate}%",
|
301 |
+
f"{second_provider_win_rate}%"
|
302 |
+
])
|
303 |
+
|
304 |
+
# Make sure all expected comparison types are included
|
305 |
+
expected_types = {"Hume AI - OpenAI", "Hume AI - ElevenLabs", "OpenAI - ElevenLabs"}
|
306 |
+
found_types = {row[0] for row in formatted_results}
|
307 |
+
|
308 |
+
# Add missing types with zero win rates
|
309 |
+
for type_name in expected_types - found_types:
|
310 |
+
formatted_results.append([type_name, "0%", "0%"])
|
311 |
+
|
312 |
+
# Sort the results by comparison type
|
313 |
+
formatted_results.sort(key=lambda x: x[0])
|
314 |
+
|
315 |
+
return formatted_results
|
316 |
+
|
317 |
+
except SQLAlchemyError as e:
|
318 |
+
logger.error(f"Database error while fetching provider win rates: {e}")
|
319 |
+
return default_win_rates
|
320 |
+
except Exception as e:
|
321 |
+
logger.error(f"Unexpected error while fetching provider win rates: {e}")
|
322 |
+
return default_win_rates
|
@@ -22,7 +22,7 @@ import gradio as gr
|
|
22 |
from src import constants
|
23 |
from src.config import Config, logger
|
24 |
from src.custom_types import Option, OptionMap
|
25 |
-
from src.database
|
26 |
from src.integrations import (
|
27 |
AnthropicError,
|
28 |
ElevenLabsError,
|
@@ -54,6 +54,8 @@ class Frontend:
|
|
54 |
|
55 |
# leaderboard update state
|
56 |
self._leaderboard_data: List[List[str]] = [[]]
|
|
|
|
|
57 |
self._leaderboard_cache_hash: Optional[str] = None
|
58 |
self._last_leaderboard_update_time: float = 0.0
|
59 |
self._min_refresh_interval = 30
|
@@ -77,7 +79,11 @@ class Frontend:
|
|
77 |
return False
|
78 |
|
79 |
# Fetch the latest data
|
80 |
-
|
|
|
|
|
|
|
|
|
81 |
|
82 |
# Generate a hash of the new data to check if it's changed
|
83 |
data_str = json.dumps(str(latest_leaderboard_data))
|
@@ -90,6 +96,8 @@ class Frontend:
|
|
90 |
|
91 |
# Update the cache and timestamp
|
92 |
self._leaderboard_data = latest_leaderboard_data
|
|
|
|
|
93 |
self._leaderboard_cache_hash = data_hash
|
94 |
self._last_leaderboard_update_time = current_time
|
95 |
logger.info("Leaderboard data updated successfully.")
|
@@ -330,7 +338,7 @@ class Frontend:
|
|
330 |
gr.update(value=character_description), # Update character description
|
331 |
)
|
332 |
|
333 |
-
async def _refresh_leaderboard(self, force: bool = False) -> gr.DataFrame:
|
334 |
"""
|
335 |
Asynchronously fetches and formats the latest leaderboard data.
|
336 |
|
@@ -338,17 +346,20 @@ class Frontend:
|
|
338 |
force (bool): If True, bypass time-based throttling.
|
339 |
|
340 |
Returns:
|
341 |
-
|
342 |
"""
|
343 |
data_updated = await self._update_leaderboard_data(force=force)
|
344 |
|
345 |
if not self._leaderboard_data:
|
346 |
raise gr.Error("Unable to retrieve leaderboard data. Please refresh the page or try again shortly.")
|
347 |
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
|
|
|
|
|
|
352 |
|
353 |
async def _handle_tab_select(self, evt: gr.SelectData):
|
354 |
"""
|
@@ -358,12 +369,11 @@ class Frontend:
|
|
358 |
evt (gr.SelectData): Event data containing information about the selected tab
|
359 |
|
360 |
Returns:
|
361 |
-
|
362 |
"""
|
363 |
-
# Check if the selected tab is "Leaderboard" by name
|
364 |
if evt.value == "Leaderboard":
|
365 |
return await self._refresh_leaderboard(force=False)
|
366 |
-
return gr.skip()
|
367 |
|
368 |
def _disable_ui(self) -> Tuple[
|
369 |
gr.Button,
|
@@ -909,6 +919,37 @@ class Frontend:
|
|
909 |
elem_id="leaderboard-table"
|
910 |
)
|
911 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
912 |
with gr.Accordion(label="Citation", open=False):
|
913 |
with gr.Column(variant="panel"):
|
914 |
with gr.Column(variant="panel"):
|
@@ -965,7 +1006,8 @@ class Frontend:
|
|
965 |
|
966 |
# Wrapper for the async refresh function
|
967 |
async def async_refresh_handler():
|
968 |
-
|
|
|
969 |
|
970 |
# Handler to re-enable the button after a refresh
|
971 |
def reenable_button():
|
@@ -980,14 +1022,14 @@ class Frontend:
|
|
980 |
).then(
|
981 |
fn=async_refresh_handler,
|
982 |
inputs=[],
|
983 |
-
outputs=[leaderboard_table]
|
984 |
).then(
|
985 |
fn=reenable_button,
|
986 |
inputs=[],
|
987 |
outputs=[refresh_button]
|
988 |
)
|
989 |
|
990 |
-
return leaderboard_table
|
991 |
|
992 |
async def build_gradio_interface(self) -> gr.Blocks:
|
993 |
"""
|
@@ -1004,12 +1046,12 @@ class Frontend:
|
|
1004 |
with gr.TabItem("Arena"):
|
1005 |
self._build_arena_section()
|
1006 |
with gr.TabItem("Leaderboard"):
|
1007 |
-
leaderboard_table = self._build_leaderboard_section()
|
1008 |
|
1009 |
tabs.select(
|
1010 |
fn=self._handle_tab_select,
|
1011 |
inputs=[],
|
1012 |
-
outputs=[leaderboard_table],
|
1013 |
)
|
1014 |
|
1015 |
logger.debug("Gradio interface built successfully")
|
|
|
22 |
from src import constants
|
23 |
from src.config import Config, logger
|
24 |
from src.custom_types import Option, OptionMap
|
25 |
+
from src.database import AsyncDBSessionMaker
|
26 |
from src.integrations import (
|
27 |
AnthropicError,
|
28 |
ElevenLabsError,
|
|
|
54 |
|
55 |
# leaderboard update state
|
56 |
self._leaderboard_data: List[List[str]] = [[]]
|
57 |
+
self._battle_counts_data: List[List[str]] = [[]]
|
58 |
+
self._win_rates_data: List[List[str]] = [[]]
|
59 |
self._leaderboard_cache_hash: Optional[str] = None
|
60 |
self._last_leaderboard_update_time: float = 0.0
|
61 |
self._min_refresh_interval = 30
|
|
|
79 |
return False
|
80 |
|
81 |
# Fetch the latest data
|
82 |
+
(
|
83 |
+
latest_leaderboard_data,
|
84 |
+
latest_battle_counts_data,
|
85 |
+
latest_win_rates_data
|
86 |
+
) = await get_leaderboard_data(self.db_session_maker)
|
87 |
|
88 |
# Generate a hash of the new data to check if it's changed
|
89 |
data_str = json.dumps(str(latest_leaderboard_data))
|
|
|
96 |
|
97 |
# Update the cache and timestamp
|
98 |
self._leaderboard_data = latest_leaderboard_data
|
99 |
+
self._battle_counts_data = latest_battle_counts_data
|
100 |
+
self._win_rates_data = latest_win_rates_data
|
101 |
self._leaderboard_cache_hash = data_hash
|
102 |
self._last_leaderboard_update_time = current_time
|
103 |
logger.info("Leaderboard data updated successfully.")
|
|
|
338 |
gr.update(value=character_description), # Update character description
|
339 |
)
|
340 |
|
341 |
+
async def _refresh_leaderboard(self, force: bool = False) -> Tuple[gr.DataFrame, gr.DataFrame, gr.DataFrame]:
|
342 |
"""
|
343 |
Asynchronously fetches and formats the latest leaderboard data.
|
344 |
|
|
|
346 |
force (bool): If True, bypass time-based throttling.
|
347 |
|
348 |
Returns:
|
349 |
+
tuple: Updated DataFrames or gr.skip() if no update needed
|
350 |
"""
|
351 |
data_updated = await self._update_leaderboard_data(force=force)
|
352 |
|
353 |
if not self._leaderboard_data:
|
354 |
raise gr.Error("Unable to retrieve leaderboard data. Please refresh the page or try again shortly.")
|
355 |
|
356 |
+
if data_updated or force:
|
357 |
+
return (
|
358 |
+
gr.update(value=self._leaderboard_data),
|
359 |
+
gr.update(value=self._battle_counts_data),
|
360 |
+
gr.update(value=self._win_rates_data)
|
361 |
+
)
|
362 |
+
return gr.skip(), gr.skip(), gr.skip()
|
363 |
|
364 |
async def _handle_tab_select(self, evt: gr.SelectData):
|
365 |
"""
|
|
|
369 |
evt (gr.SelectData): Event data containing information about the selected tab
|
370 |
|
371 |
Returns:
|
372 |
+
tuple: Updates for the three tables if data changed, otherwise skip
|
373 |
"""
|
|
|
374 |
if evt.value == "Leaderboard":
|
375 |
return await self._refresh_leaderboard(force=False)
|
376 |
+
return gr.skip(), gr.skip(), gr.skip()
|
377 |
|
378 |
def _disable_ui(self) -> Tuple[
|
379 |
gr.Button,
|
|
|
919 |
elem_id="leaderboard-table"
|
920 |
)
|
921 |
|
922 |
+
with gr.Column():
|
923 |
+
gr.HTML(
|
924 |
+
value="""
|
925 |
+
<h2 style="padding-top: 12px;" class="tab-header">📊 Head-to-Head Matchups</h2>
|
926 |
+
<p style="padding-left: 8px; width: 80%;">
|
927 |
+
These tables show how each provider performs against others in direct comparisons.
|
928 |
+
The first table shows the total number of comparisons between each pair of providers.
|
929 |
+
The second table shows the win rate (percentage) of the row provider against the column provider.
|
930 |
+
</p>
|
931 |
+
""",
|
932 |
+
padding=False
|
933 |
+
)
|
934 |
+
|
935 |
+
with gr.Row(equal_height=True):
|
936 |
+
with gr.Column(min_width=420):
|
937 |
+
battle_counts_table = gr.DataFrame(
|
938 |
+
headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
|
939 |
+
datatype=["html", "html", "html", "html"],
|
940 |
+
column_widths=[132, 132, 132, 132],
|
941 |
+
value=self._battle_counts_data,
|
942 |
+
interactive=False,
|
943 |
+
)
|
944 |
+
with gr.Column(min_width=420):
|
945 |
+
win_rates_table = gr.DataFrame(
|
946 |
+
headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
|
947 |
+
datatype=["html", "html", "html", "html"],
|
948 |
+
column_widths=[132, 132, 132, 132],
|
949 |
+
value=self._win_rates_data,
|
950 |
+
interactive=False,
|
951 |
+
)
|
952 |
+
|
953 |
with gr.Accordion(label="Citation", open=False):
|
954 |
with gr.Column(variant="panel"):
|
955 |
with gr.Column(variant="panel"):
|
|
|
1006 |
|
1007 |
# Wrapper for the async refresh function
|
1008 |
async def async_refresh_handler():
|
1009 |
+
leaderboard_update, battle_counts_update, win_rates_update = await self._refresh_leaderboard(force=True)
|
1010 |
+
return leaderboard_update, battle_counts_update, win_rates_update
|
1011 |
|
1012 |
# Handler to re-enable the button after a refresh
|
1013 |
def reenable_button():
|
|
|
1022 |
).then(
|
1023 |
fn=async_refresh_handler,
|
1024 |
inputs=[],
|
1025 |
+
outputs=[leaderboard_table, battle_counts_table, win_rates_table] # Update all three tables
|
1026 |
).then(
|
1027 |
fn=reenable_button,
|
1028 |
inputs=[],
|
1029 |
outputs=[refresh_button]
|
1030 |
)
|
1031 |
|
1032 |
+
return leaderboard_table, battle_counts_table, win_rates_table
|
1033 |
|
1034 |
async def build_gradio_interface(self) -> gr.Blocks:
|
1035 |
"""
|
|
|
1046 |
with gr.TabItem("Arena"):
|
1047 |
self._build_arena_section()
|
1048 |
with gr.TabItem("Leaderboard"):
|
1049 |
+
leaderboard_table, battle_counts_table, win_rates_table = self._build_leaderboard_section()
|
1050 |
|
1051 |
tabs.select(
|
1052 |
fn=self._handle_tab_select,
|
1053 |
inputs=[],
|
1054 |
+
outputs=[leaderboard_table, battle_counts_table, win_rates_table],
|
1055 |
)
|
1056 |
|
1057 |
logger.debug("Gradio interface built successfully")
|
@@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
|
13 |
|
14 |
# Local Application Imports
|
15 |
from src.config import Config, logger
|
16 |
-
from src.database
|
17 |
|
18 |
|
19 |
async def init_tables():
|
|
|
13 |
|
14 |
# Local Application Imports
|
15 |
from src.config import Config, logger
|
16 |
+
from src.database import Base
|
17 |
|
18 |
|
19 |
async def init_tables():
|
@@ -34,7 +34,7 @@ from sqlalchemy import text
|
|
34 |
|
35 |
# Local Application Imports
|
36 |
from src.config import Config, logger
|
37 |
-
from src.database
|
38 |
|
39 |
|
40 |
async def test_connection_async():
|
|
|
34 |
|
35 |
# Local Application Imports
|
36 |
from src.config import Config, logger
|
37 |
+
from src.database import engine, init_db
|
38 |
|
39 |
|
40 |
async def test_connection_async():
|
@@ -23,14 +23,20 @@ from src import constants
|
|
23 |
from src.config import Config, logger
|
24 |
from src.custom_types import (
|
25 |
ComparisonType,
|
|
|
26 |
Option,
|
27 |
OptionKey,
|
28 |
OptionMap,
|
29 |
TTSProviderName,
|
30 |
VotingResults,
|
31 |
)
|
32 |
-
from src.database import
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
34 |
|
35 |
|
36 |
def truncate_text(text: str, max_length: int = 50) -> str:
|
@@ -374,7 +380,7 @@ async def _persist_vote(db_session_maker: AsyncDBSessionMaker, voting_results: V
|
|
374 |
session = await _create_db_session(db_session_maker)
|
375 |
_log_voting_results(voting_results)
|
376 |
try:
|
377 |
-
await
|
378 |
except Exception as e:
|
379 |
# Log the error with traceback
|
380 |
logger.error(f"Failed to create vote record: {e}", exc_info=True)
|
@@ -434,49 +440,159 @@ async def submit_voting_results(
|
|
434 |
logger.error(f"Background task error in submit_voting_results: {e}", exc_info=True)
|
435 |
|
436 |
|
437 |
-
async def get_leaderboard_data(
|
|
|
|
|
438 |
"""
|
439 |
-
Fetches leaderboard data from voting results database
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
440 |
|
441 |
Returns:
|
442 |
-
|
443 |
-
|
|
|
|
|
444 |
"""
|
445 |
# Create session
|
446 |
session = await _create_db_session(db_session_maker)
|
447 |
try:
|
448 |
-
|
|
|
|
|
|
|
449 |
logger.info("Fetched leaderboard data successfully.")
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
href="{constants.TTS_PROVIDER_LINKS[row[1]]["provider_link"]}"
|
457 |
-
target="_blank"
|
458 |
-
class="provider-link"
|
459 |
-
>{row[1]}</a>
|
460 |
-
""",
|
461 |
-
f"""<a
|
462 |
-
href="{constants.TTS_PROVIDER_LINKS[row[1]]["model_link"]}"
|
463 |
-
target="_blank"
|
464 |
-
class="provider-link"
|
465 |
-
>{row[2]}</a>
|
466 |
-
""",
|
467 |
-
f'<p style="text-align: center;">{row[3]}</p>',
|
468 |
-
f'<p style="text-align: center;">{row[4]}</p>',
|
469 |
-
] for row in leaderboard_data
|
470 |
-
]
|
471 |
except Exception as e:
|
472 |
# Log the error with traceback
|
473 |
logger.error(f"Failed to fetch leaderboard data: {e}", exc_info=True)
|
474 |
-
return []
|
475 |
finally:
|
476 |
# Always ensure the session is closed
|
477 |
if session is not None:
|
478 |
await session.close()
|
479 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
480 |
|
481 |
def validate_env_var(var_name: str) -> str:
|
482 |
"""
|
|
|
23 |
from src.config import Config, logger
|
24 |
from src.custom_types import (
|
25 |
ComparisonType,
|
26 |
+
LeaderboardEntry,
|
27 |
Option,
|
28 |
OptionKey,
|
29 |
OptionMap,
|
30 |
TTSProviderName,
|
31 |
VotingResults,
|
32 |
)
|
33 |
+
from src.database import (
|
34 |
+
AsyncDBSessionMaker,
|
35 |
+
create_vote,
|
36 |
+
get_head_to_head_battle_stats,
|
37 |
+
get_head_to_head_win_rate_stats,
|
38 |
+
get_leaderboard_stats,
|
39 |
+
)
|
40 |
|
41 |
|
42 |
def truncate_text(text: str, max_length: int = 50) -> str:
|
|
|
380 |
session = await _create_db_session(db_session_maker)
|
381 |
_log_voting_results(voting_results)
|
382 |
try:
|
383 |
+
await create_vote(cast(AsyncSession, session), voting_results)
|
384 |
except Exception as e:
|
385 |
# Log the error with traceback
|
386 |
logger.error(f"Failed to create vote record: {e}", exc_info=True)
|
|
|
440 |
logger.error(f"Background task error in submit_voting_results: {e}", exc_info=True)
|
441 |
|
442 |
|
443 |
+
async def get_leaderboard_data(
|
444 |
+
db_session_maker: AsyncDBSessionMaker
|
445 |
+
) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
|
446 |
"""
|
447 |
+
Fetches and formats all leaderboard data from the voting results database.
|
448 |
+
|
449 |
+
This function retrieves three different datasets:
|
450 |
+
1. Provider rankings with overall performance metrics
|
451 |
+
2. Head-to-head battle counts between providers
|
452 |
+
3. Win rate percentages for each provider against others
|
453 |
+
|
454 |
+
Args:
|
455 |
+
db_session_maker (AsyncDBSessionMaker): Factory function for creating async database sessions.
|
456 |
|
457 |
Returns:
|
458 |
+
Tuple containing three datasets, each as List[List[str]]:
|
459 |
+
- leaderboard_data: Provider rankings with performance metrics
|
460 |
+
- battle_counts_data: Number of comparisons between each provider pair
|
461 |
+
- win_rate_data: Win percentages in head-to-head matchups
|
462 |
"""
|
463 |
# Create session
|
464 |
session = await _create_db_session(db_session_maker)
|
465 |
try:
|
466 |
+
leaderboard_data_raw = await get_leaderboard_stats(cast(AsyncSession, session))
|
467 |
+
battle_counts_data_raw = await get_head_to_head_battle_stats(cast(AsyncSession, session))
|
468 |
+
win_rate_data_raw = await get_head_to_head_win_rate_stats(cast(AsyncSession, session))
|
469 |
+
|
470 |
logger.info("Fetched leaderboard data successfully.")
|
471 |
+
|
472 |
+
leaderboard_data = _format_leaderboard_data(leaderboard_data_raw)
|
473 |
+
battle_counts_data = _format_battle_counts_data(battle_counts_data_raw)
|
474 |
+
win_rate_data = _format_win_rate_data(win_rate_data_raw)
|
475 |
+
|
476 |
+
return leaderboard_data, battle_counts_data, win_rate_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
477 |
except Exception as e:
|
478 |
# Log the error with traceback
|
479 |
logger.error(f"Failed to fetch leaderboard data: {e}", exc_info=True)
|
480 |
+
return [[]], [[]], [[]]
|
481 |
finally:
|
482 |
# Always ensure the session is closed
|
483 |
if session is not None:
|
484 |
await session.close()
|
485 |
|
486 |
+
def _format_leaderboard_data(leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
|
487 |
+
"""
|
488 |
+
Formats raw leaderboard data for display in the UI.
|
489 |
+
|
490 |
+
Converts LeaderboardEntry objects into HTML-formatted strings with appropriate
|
491 |
+
styling and links for provider and model information.
|
492 |
+
|
493 |
+
Args:
|
494 |
+
leaderboard_data_raw (List[LeaderboardEntry]): Raw leaderboard data from the database.
|
495 |
+
|
496 |
+
Returns:
|
497 |
+
List[List[str]]: Formatted HTML strings for each cell in the leaderboard table.
|
498 |
+
"""
|
499 |
+
return [
|
500 |
+
[
|
501 |
+
f'<p style="text-align: center;">{row[0]}</p>',
|
502 |
+
f"""<a href="{constants.TTS_PROVIDER_LINKS[row[1]]["provider_link"]}"
|
503 |
+
target="_blank"
|
504 |
+
class="provider-link"
|
505 |
+
>{row[1]}</a>
|
506 |
+
""",
|
507 |
+
f"""<a href="{constants.TTS_PROVIDER_LINKS[row[1]]["model_link"]}"
|
508 |
+
target="_blank"
|
509 |
+
class="provider-link"
|
510 |
+
>{row[2]}</a>
|
511 |
+
""",
|
512 |
+
f'<p style="text-align: center;">{row[3]}</p>',
|
513 |
+
f'<p style="text-align: center;">{row[4]}</p>',
|
514 |
+
] for row in leaderboard_data_raw
|
515 |
+
]
|
516 |
+
|
517 |
+
|
518 |
+
def _format_battle_counts_data(battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
|
519 |
+
"""
|
520 |
+
Formats battle count data into a matrix format for the UI.
|
521 |
+
|
522 |
+
Creates a provider-by-provider matrix showing the number of direct comparisons
|
523 |
+
between each pair of providers. Diagonal cells show dashes as providers aren't
|
524 |
+
compared against themselves.
|
525 |
+
|
526 |
+
Args:
|
527 |
+
battle_counts_data_raw (List[List[str]]): Raw battle count data from the database,
|
528 |
+
where each inner list contains [comparison_type, count].
|
529 |
+
|
530 |
+
Returns:
|
531 |
+
List[List[str]]: HTML-formatted matrix of battle counts between providers.
|
532 |
+
"""
|
533 |
+
battle_counts_dict = {item[0]: item[1] for item in battle_counts_data_raw}
|
534 |
+
# Create canonical comparison keys based on your expected database formats
|
535 |
+
comparison_keys = {
|
536 |
+
("Hume AI", "OpenAI"): "Hume AI - OpenAI",
|
537 |
+
("Hume AI", "ElevenLabs"): "Hume AI - ElevenLabs",
|
538 |
+
("OpenAI", "ElevenLabs"): "OpenAI - ElevenLabs"
|
539 |
+
}
|
540 |
+
return [
|
541 |
+
[
|
542 |
+
f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'
|
543 |
+
] + [
|
544 |
+
f"""
|
545 |
+
<p style="text-align: center;">
|
546 |
+
{"-" if row_provider == col_provider
|
547 |
+
else battle_counts_dict.get(
|
548 |
+
comparison_keys.get((row_provider, col_provider)) or
|
549 |
+
comparison_keys.get((col_provider, row_provider), "unknown"),
|
550 |
+
"0"
|
551 |
+
)
|
552 |
+
}
|
553 |
+
</p>
|
554 |
+
""" for col_provider in constants.TTS_PROVIDERS
|
555 |
+
]
|
556 |
+
for row_provider in constants.TTS_PROVIDERS
|
557 |
+
]
|
558 |
+
|
559 |
+
|
560 |
+
def _format_win_rate_data(win_rate_data_raw: List[List[str]]) -> List[List[str]]:
|
561 |
+
"""
|
562 |
+
Formats win rate data into a matrix format for the UI.
|
563 |
+
|
564 |
+
Creates a provider-by-provider matrix showing the percentage of times the row
|
565 |
+
provider won against the column provider. Diagonal cells show dashes as
|
566 |
+
providers aren't compared against themselves.
|
567 |
+
|
568 |
+
Args:
|
569 |
+
win_rate_data_raw (List[List[str]]): Raw win rate data from the database,
|
570 |
+
where each inner list contains [comparison_type, first_win_rate, second_win_rate].
|
571 |
+
|
572 |
+
Returns:
|
573 |
+
List[List[str]]: HTML-formatted matrix of win rates between providers.
|
574 |
+
"""
|
575 |
+
# Create a clean lookup dictionary with provider pairs as keys
|
576 |
+
win_rates = {}
|
577 |
+
for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
|
578 |
+
provider1, provider2 = comparison_type.split(" - ")
|
579 |
+
win_rates[(provider1, provider2)] = first_win_rate
|
580 |
+
win_rates[(provider2, provider1)] = second_win_rate
|
581 |
+
|
582 |
+
return [
|
583 |
+
[
|
584 |
+
f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'
|
585 |
+
] + [
|
586 |
+
f"""
|
587 |
+
<p style="text-align: center;">
|
588 |
+
{"-" if row_provider == col_provider else win_rates.get((row_provider, col_provider), "0%")}
|
589 |
+
</p>
|
590 |
+
"""
|
591 |
+
for col_provider in constants.TTS_PROVIDERS
|
592 |
+
]
|
593 |
+
for row_provider in constants.TTS_PROVIDERS
|
594 |
+
]
|
595 |
+
|
596 |
|
597 |
def validate_env_var(var_name: str) -> str:
|
598 |
"""
|