Zachary Greathouse (twitchard) committed
Commit 97b3bfd · unverified · 1 parent: 9ed181c

Zg/add head to head results (#19)


* Add OpenAI python SDK to dependencies

* Fix Anthropic client API error message.

* Update constants and custom types associated with TTS providers to include OpenAI

* Add OpenAI integration

* Update logic for selecting providers; add OpenAI TTS to UI

* Fix typo in openai_api.py

* Update docstrings in openai_api.py

* Update leaderboard results query to include OpenAI results

* Add citation

* Adjust padding in UI components

* Adjust padding in UI components in citation

* Add transitive dependency override for sounddevice in pyproject.toml

* Remove sounddevice

* Add warning toast for custom text inputs

* Improve leaderboard results query to account for zero records, and update it to include only the comparison types relevant to each provider.

* Fix database package imports and add head-to-head comparison queries

* Add utils for fetching head-to-head comparison data

* Update UI to include head-to-head comparison tables; update leaderboard data fetching in the UI to include comparison data

---------

Co-authored-by: twitchard <[email protected]>

src/constants.py CHANGED
@@ -10,7 +10,6 @@ from typing import Dict, List
 # Third-Party Library Imports
 from src.custom_types import (
     ComparisonType,
-    LeaderboardEntry,
     OptionKey,
     OptionLabel,
     TTSProviderName,
@@ -26,7 +25,7 @@ HUME_AI: TTSProviderName = "Hume AI"
 ELEVENLABS: TTSProviderName = "ElevenLabs"
 OPENAI: TTSProviderName = "OpenAI"
 
-TTS_PROVIDERS: List[TTSProviderName] = ["Hume AI", "ElevenLabs", "OpenAI"]
+TTS_PROVIDERS: List[TTSProviderName] = ["Hume AI", "OpenAI", "ElevenLabs"]
 TTS_PROVIDER_LINKS = {
     "Hume AI": {
         "provider_link": "https://hume.ai/",
@@ -169,10 +168,3 @@ META_TAGS: List[Dict[str, str]] = [
         'content': '/static/arena-opengraph-logo.png'
     }
 ]
-
-# Reflects and empty leaderboard state
-DEFAULT_LEADERBOARD: List[LeaderboardEntry] = [
-    LeaderboardEntry("1", "", "", "0%", "0"),
-    LeaderboardEntry("2", "", "", "0%", "0"),
-    LeaderboardEntry("3", "", "", "0%", "0"),
-]
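Note: the reordering of TTS_PROVIDERS is load-bearing, because the head-to-head matrices added in src/utils.py below iterate this list for both rows and columns, so list order becomes display order. A minimal standalone sketch of that pattern (cell contents simplified, not the PR's HTML formatting):

```python
# Sketch: a provider list driving a providers-by-providers matrix.
# The list mirrors TTS_PROVIDERS from this PR; cell text is illustrative.
TTS_PROVIDERS = ["Hume AI", "OpenAI", "ElevenLabs"]

matrix = [
    ["-" if row == col else f"{row} vs {col}" for col in TTS_PROVIDERS]
    for row in TTS_PROVIDERS
]

for line in matrix:
    print(line)
```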
src/database/__init__.py CHANGED
@@ -1,10 +1,15 @@
-from .crud import create_vote
+from .crud import create_vote, get_head_to_head_battle_stats, get_head_to_head_win_rate_stats, get_leaderboard_stats
 from .database import AsyncDBSessionMaker, Base, engine, init_db
+from .models import VoteResult
 
 __all__ = [
     "AsyncDBSessionMaker",
     "Base",
+    "VoteResult",
     "create_vote",
     "engine",
+    "get_head_to_head_battle_stats",
+    "get_head_to_head_win_rate_stats",
+    "get_leaderboard_stats",
     "init_db",
 ]
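These re-exports let callers import from the package root instead of reaching into submodules, which is exactly what the later diffs to src/frontend.py, src/scripts/init_db.py, src/scripts/test_db.py, and src/utils.py switch to. Roughly:

```python
# Before this PR: imports reached into submodules.
# from src.database.database import AsyncDBSessionMaker
# from src.database.crud import create_vote

# After this PR: everything is re-exported from the package root.
from src.database import (
    AsyncDBSessionMaker,
    create_vote,
    get_head_to_head_battle_stats,
    get_head_to_head_win_rate_stats,
    get_leaderboard_stats,
)
```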
src/database/crud.py CHANGED
@@ -5,6 +5,9 @@ This module defines the operations for the Expressive TTS Arena project's databa
 Since vote records are never updated or deleted, only functions to create and read votes are provided.
 """
 
+# Standard Library Imports
+from typing import List
+
 # Third-Party Library Imports
 from sqlalchemy import text
 from sqlalchemy.exc import SQLAlchemyError
@@ -12,7 +15,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
 
 # Local Application Imports
 from src.config import logger
-from src.constants import DEFAULT_LEADERBOARD
 from src.custom_types import LeaderboardEntry, LeaderboardTableEntries, VotingResults
 from src.database.models import VoteResult
 
@@ -83,6 +85,12 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
         LeaderboardTableEntries: A list of LeaderboardEntry objects containing rank,
             provider name, model name, win rate, and total votes.
     """
+    default_leaderboard = [
+        LeaderboardEntry("1", "", "", "0%", "0"),
+        LeaderboardEntry("2", "", "", "0%", "0"),
+        LeaderboardEntry("3", "", "", "0%", "0"),
+    ]
+
     try:
         query = text(
             """
@@ -137,6 +145,10 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
         result = await db.execute(query)
         rows = result.fetchall()
 
+        # If no rows, return default
+        if not rows:
+            return default_leaderboard
+
         # Format the data for the leaderboard
         leaderboard_data = []
         for i, row in enumerate(rows, 1):
@@ -150,16 +162,161 @@ async def get_leaderboard_stats(db: AsyncSession) -> LeaderboardTableEntries:
             )
             leaderboard_data.append(leaderboard_entry)
 
-        # If no data was found, return default entries
-        if not leaderboard_data:
-            return DEFAULT_LEADERBOARD
-
         return leaderboard_data
 
     except SQLAlchemyError as e:
         logger.error(f"Database error while fetching leaderboard stats: {e}")
-        return DEFAULT_LEADERBOARD
+        return default_leaderboard
     except Exception as e:
         logger.error(f"Unexpected error while fetching leaderboard stats: {e}")
-        return DEFAULT_LEADERBOARD
+        return default_leaderboard
+
+
+async def get_head_to_head_battle_stats(db: AsyncSession) -> List[List[str]]:
+    """
+    Fetches the total number of voting results for each comparison type (excluding "Hume AI - Hume AI").
+
+    Args:
+        db (AsyncSession): The SQLAlchemy async database session.
+
+    Returns:
+        List[List[str]]: A list of lists, where each inner list contains the comparison type and the count.
+    """
+    default_counts = [
+        ["Hume AI - OpenAI", "0"],
+        ["Hume AI - ElevenLabs", "0"],
+        ["OpenAI - ElevenLabs", "0"],
+    ]
+
+    try:
+        query = text(
+            """
+            SELECT
+                comparison_type,
+                COUNT(*) as total
+            FROM vote_results
+            WHERE comparison_type != 'Hume AI - Hume AI'
+            GROUP BY comparison_type
+            ORDER BY comparison_type;
+            """
+        )
+
+        result = await db.execute(query)
+        rows = result.fetchall()
+
+        # If no rows, return default
+        if not rows:
+            return default_counts
+
+        # Format the results
+        formatted_results = []
+        for row in rows:
+            comparison_type, count = row
+            formatted_results.append([comparison_type, str(count)])
+
+        # Make sure all expected comparison types are included
+        expected_types = {"Hume AI - OpenAI", "Hume AI - ElevenLabs", "OpenAI - ElevenLabs"}
+        found_types = {row[0] for row in formatted_results}
+
+        # Add missing types with zero counts
+        for type_name in expected_types - found_types:
+            formatted_results.append([type_name, "0"])
+
+        # Sort the results by comparison type
+        formatted_results.sort(key=lambda x: x[0])
+
+        return formatted_results
+
+    except SQLAlchemyError as e:
+        logger.error(f"Database error while fetching comparison counts: {e}")
+        return default_counts
+    except Exception as e:
+        logger.error(f"Unexpected error while fetching comparison counts: {e}")
+        return default_counts
+
+
+async def get_head_to_head_win_rate_stats(db: AsyncSession) -> List[List[str]]:
+    """
+    Calculates the win rate for each provider against the other in head-to-head comparisons.
+
+    Args:
+        db (AsyncSession): The SQLAlchemy async database session.
+
+    Returns:
+        List[List[str]]: A list of lists, where each inner list contains:
+            - The comparison type
+            - The win rate of the first provider (the one named first in the comparison type)
+            - The win rate of the second provider (the one named second in the comparison type)
+    """
+    default_win_rates = [
+        ["Hume AI - OpenAI", "0%", "0%"],
+        ["Hume AI - ElevenLabs", "0%", "0%"],
+        ["OpenAI - ElevenLabs", "0%", "0%"],
+    ]
+
+    try:
+        query = text(
+            """
+            SELECT
+                comparison_type,
+                CASE WHEN COUNT(*) > 0
+                    THEN ROUND(SUM(CASE
+                        WHEN comparison_type = 'Hume AI - OpenAI' AND winning_provider = 'Hume AI' THEN 1
+                        WHEN comparison_type = 'Hume AI - ElevenLabs' AND winning_provider = 'Hume AI' THEN 1
+                        WHEN comparison_type = 'OpenAI - ElevenLabs' AND winning_provider = 'OpenAI' THEN 1
+                        ELSE 0
+                    END) * 100.0 / COUNT(*), 2)
+                    ELSE 0
+                END as first_provider_win_rate,
+                CASE WHEN COUNT(*) > 0
+                    THEN ROUND(SUM(CASE
+                        WHEN comparison_type = 'Hume AI - OpenAI' AND winning_provider = 'OpenAI' THEN 1
+                        WHEN comparison_type = 'Hume AI - ElevenLabs' AND winning_provider = 'ElevenLabs' THEN 1
+                        WHEN comparison_type = 'OpenAI - ElevenLabs' AND winning_provider = 'ElevenLabs' THEN 1
+                        ELSE 0
+                    END) * 100.0 / COUNT(*), 2)
+                    ELSE 0
+                END as second_provider_win_rate
+            FROM vote_results
+            WHERE comparison_type != 'Hume AI - Hume AI'
+            GROUP BY comparison_type
+            ORDER BY comparison_type;
+            """
+        )
+
+        result = await db.execute(query)
+        rows = result.fetchall()
+
+        # If no rows, return default
+        if not rows:
+            return default_win_rates
+
+        # Format the results
+        formatted_results = []
+        for row in rows:
+            comparison_type, first_provider_win_rate, second_provider_win_rate = row
+            formatted_results.append([
+                comparison_type,
+                f"{first_provider_win_rate}%",
+                f"{second_provider_win_rate}%"
+            ])
+
+        # Make sure all expected comparison types are included
+        expected_types = {"Hume AI - OpenAI", "Hume AI - ElevenLabs", "OpenAI - ElevenLabs"}
+        found_types = {row[0] for row in formatted_results}
+
+        # Add missing types with zero win rates
+        for type_name in expected_types - found_types:
+            formatted_results.append([type_name, "0%", "0%"])
+
+        # Sort the results by comparison type
+        formatted_results.sort(key=lambda x: x[0])
+
+        return formatted_results
+
+    except SQLAlchemyError as e:
+        logger.error(f"Database error while fetching provider win rates: {e}")
+        return default_win_rates
+    except Exception as e:
+        logger.error(f"Unexpected error while fetching provider win rates: {e}")
+        return default_win_rates
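For reference, a rough usage sketch for the two new read functions. The session handling here is illustrative (the project goes through its own _create_db_session helper in src/utils.py) and assumes AsyncDBSessionMaker behaves like SQLAlchemy's async_sessionmaker:

```python
from src.database import (
    AsyncDBSessionMaker,
    get_head_to_head_battle_stats,
    get_head_to_head_win_rate_stats,
)


async def show_head_to_head(db_session_maker: AsyncDBSessionMaker) -> None:
    # Open a session; both functions fall back to zeroed defaults on error.
    async with db_session_maker() as session:
        counts = await get_head_to_head_battle_stats(session)
        win_rates = await get_head_to_head_win_rate_stats(session)

    # counts rows look like ["Hume AI - OpenAI", "12"]
    for comparison_type, total in counts:
        print(f"{comparison_type}: {total} battles")

    # win_rates rows look like ["Hume AI - OpenAI", "58.33%", "41.67%"]
    for comparison_type, first, second in win_rates:
        print(f"{comparison_type}: {first} vs {second}")
```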
src/frontend.py CHANGED
@@ -22,7 +22,7 @@ import gradio as gr
 from src import constants
 from src.config import Config, logger
 from src.custom_types import Option, OptionMap
-from src.database.database import AsyncDBSessionMaker
+from src.database import AsyncDBSessionMaker
 from src.integrations import (
     AnthropicError,
     ElevenLabsError,
@@ -54,6 +54,8 @@ class Frontend:
 
         # leaderboard update state
         self._leaderboard_data: List[List[str]] = [[]]
+        self._battle_counts_data: List[List[str]] = [[]]
+        self._win_rates_data: List[List[str]] = [[]]
         self._leaderboard_cache_hash: Optional[str] = None
         self._last_leaderboard_update_time: float = 0.0
         self._min_refresh_interval = 30
@@ -77,7 +79,11 @@ class Frontend:
             return False
 
         # Fetch the latest data
-        latest_leaderboard_data = await get_leaderboard_data(self.db_session_maker)
+        (
+            latest_leaderboard_data,
+            latest_battle_counts_data,
+            latest_win_rates_data
+        ) = await get_leaderboard_data(self.db_session_maker)
 
         # Generate a hash of the new data to check if it's changed
         data_str = json.dumps(str(latest_leaderboard_data))
@@ -90,6 +96,8 @@ class Frontend:
 
         # Update the cache and timestamp
         self._leaderboard_data = latest_leaderboard_data
+        self._battle_counts_data = latest_battle_counts_data
+        self._win_rates_data = latest_win_rates_data
         self._leaderboard_cache_hash = data_hash
         self._last_leaderboard_update_time = current_time
         logger.info("Leaderboard data updated successfully.")
@@ -330,7 +338,7 @@ class Frontend:
             gr.update(value=character_description),  # Update character description
         )
 
-    async def _refresh_leaderboard(self, force: bool = False) -> gr.DataFrame:
+    async def _refresh_leaderboard(self, force: bool = False) -> Tuple[gr.DataFrame, gr.DataFrame, gr.DataFrame]:
         """
         Asynchronously fetches and formats the latest leaderboard data.
 
@@ -338,17 +346,20 @@ class Frontend:
             force (bool): If True, bypass time-based throttling.
 
         Returns:
-            gr.DataFrame: Updated DataFrame or gr.skip() if no update needed
+            tuple: Updated DataFrames or gr.skip() if no update needed
         """
         data_updated = await self._update_leaderboard_data(force=force)
 
         if not self._leaderboard_data:
             raise gr.Error("Unable to retrieve leaderboard data. Please refresh the page or try again shortly.")
 
-        # Only return an update if the data changed or force=True
-        if data_updated:
-            return gr.update(value=self._leaderboard_data)
-        return gr.skip()
+        if data_updated or force:
+            return (
+                gr.update(value=self._leaderboard_data),
+                gr.update(value=self._battle_counts_data),
+                gr.update(value=self._win_rates_data)
+            )
+        return gr.skip(), gr.skip(), gr.skip()
 
     async def _handle_tab_select(self, evt: gr.SelectData):
         """
@@ -358,12 +369,11 @@ class Frontend:
             evt (gr.SelectData): Event data containing information about the selected tab
 
         Returns:
-            gr.update or gr.skip: Update for the leaderboard table if data changed, otherwise skip
+            tuple: Updates for the three tables if data changed, otherwise skip
         """
-        # Check if the selected tab is "Leaderboard" by name
         if evt.value == "Leaderboard":
             return await self._refresh_leaderboard(force=False)
-        return gr.skip()
+        return gr.skip(), gr.skip(), gr.skip()
 
     def _disable_ui(self) -> Tuple[
         gr.Button,
@@ -909,6 +919,37 @@ class Frontend:
                 elem_id="leaderboard-table"
             )
 
+            with gr.Column():
+                gr.HTML(
+                    value="""
+                    <h2 style="padding-top: 12px;" class="tab-header">📊 Head-to-Head Matchups</h2>
+                    <p style="padding-left: 8px; width: 80%;">
+                        These tables show how each provider performs against others in direct comparisons.
+                        The first table shows the total number of comparisons between each pair of providers.
+                        The second table shows the win rate (percentage) of the row provider against the column provider.
+                    </p>
+                    """,
+                    padding=False
+                )
+
+            with gr.Row(equal_height=True):
+                with gr.Column(min_width=420):
+                    battle_counts_table = gr.DataFrame(
+                        headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
+                        datatype=["html", "html", "html", "html"],
+                        column_widths=[132, 132, 132, 132],
+                        value=self._battle_counts_data,
+                        interactive=False,
+                    )
+                with gr.Column(min_width=420):
+                    win_rates_table = gr.DataFrame(
+                        headers=["", "Hume AI", "OpenAI", "ElevenLabs"],
+                        datatype=["html", "html", "html", "html"],
+                        column_widths=[132, 132, 132, 132],
+                        value=self._win_rates_data,
+                        interactive=False,
+                    )
+
             with gr.Accordion(label="Citation", open=False):
                 with gr.Column(variant="panel"):
                     with gr.Column(variant="panel"):
@@ -965,7 +1006,8 @@ class Frontend:
 
             # Wrapper for the async refresh function
            async def async_refresh_handler():
-                return await self._refresh_leaderboard(force=True)
+                leaderboard_update, battle_counts_update, win_rates_update = await self._refresh_leaderboard(force=True)
+                return leaderboard_update, battle_counts_update, win_rates_update
 
             # Handler to re-enable the button after a refresh
             def reenable_button():
@@ -980,14 +1022,14 @@ class Frontend:
             ).then(
                 fn=async_refresh_handler,
                 inputs=[],
-                outputs=[leaderboard_table]
+                outputs=[leaderboard_table, battle_counts_table, win_rates_table]  # Update all three tables
             ).then(
                 fn=reenable_button,
                 inputs=[],
                 outputs=[refresh_button]
             )
 
-        return leaderboard_table
+        return leaderboard_table, battle_counts_table, win_rates_table
 
     async def build_gradio_interface(self) -> gr.Blocks:
         """
@@ -1004,12 +1046,12 @@ class Frontend:
             with gr.TabItem("Arena"):
                 self._build_arena_section()
             with gr.TabItem("Leaderboard"):
-                leaderboard_table = self._build_leaderboard_section()
+                leaderboard_table, battle_counts_table, win_rates_table = self._build_leaderboard_section()
 
         tabs.select(
             fn=self._handle_tab_select,
             inputs=[],
-            outputs=[leaderboard_table],
+            outputs=[leaderboard_table, battle_counts_table, win_rates_table],
         )
 
         logger.debug("Gradio interface built successfully")
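One Gradio detail worth calling out: an event handler wired to multiple outputs must return one value per output component, which is why the no-op paths above return three gr.skip() values. A self-contained sketch of the pattern (component names hypothetical, assuming a Gradio version that provides gr.skip(), as the PR itself uses):

```python
import gradio as gr


def maybe_update(data_changed: bool):
    # Three output components -> return a 3-tuple: either an update per
    # component, or gr.skip() per component to leave it untouched.
    if data_changed:
        return (
            gr.update(value=[["1", "Hume AI"]]),
            gr.update(value=[["Hume AI - OpenAI", "12"]]),
            gr.update(value=[["Hume AI - OpenAI", "58.33%", "41.67%"]]),
        )
    return gr.skip(), gr.skip(), gr.skip()


with gr.Blocks() as demo:
    leaderboard = gr.DataFrame()
    battle_counts = gr.DataFrame()
    win_rates = gr.DataFrame()
    refresh = gr.Button("Refresh")
    refresh.click(
        fn=lambda: maybe_update(True),
        inputs=[],
        outputs=[leaderboard, battle_counts, win_rates],
    )
```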
src/scripts/init_db.py CHANGED
@@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
 
 # Local Application Imports
 from src.config import Config, logger
-from src.database.models import Base
+from src.database import Base
 
 
 async def init_tables():
src/scripts/test_db.py CHANGED
@@ -34,7 +34,7 @@ from sqlalchemy import text
 
 # Local Application Imports
 from src.config import Config, logger
-from src.database.database import engine, init_db
+from src.database import engine, init_db
 
 
 async def test_connection_async():
src/utils.py CHANGED
@@ -23,14 +23,20 @@ from src import constants
 from src.config import Config, logger
 from src.custom_types import (
     ComparisonType,
+    LeaderboardEntry,
     Option,
     OptionKey,
     OptionMap,
     TTSProviderName,
     VotingResults,
 )
-from src.database import crud
-from src.database.database import AsyncDBSessionMaker
+from src.database import (
+    AsyncDBSessionMaker,
+    create_vote,
+    get_head_to_head_battle_stats,
+    get_head_to_head_win_rate_stats,
+    get_leaderboard_stats,
+)
 
 
 def truncate_text(text: str, max_length: int = 50) -> str:
@@ -374,7 +380,7 @@ async def _persist_vote(db_session_maker: AsyncDBSessionMaker, voting_results: V
     session = await _create_db_session(db_session_maker)
     _log_voting_results(voting_results)
     try:
-        await crud.create_vote(cast(AsyncSession, session), voting_results)
+        await create_vote(cast(AsyncSession, session), voting_results)
     except Exception as e:
         # Log the error with traceback
         logger.error(f"Failed to create vote record: {e}", exc_info=True)
@@ -434,49 +440,159 @@ async def submit_voting_results(
         logger.error(f"Background task error in submit_voting_results: {e}", exc_info=True)
 
 
-async def get_leaderboard_data(db_session_maker: AsyncDBSessionMaker) -> List[List[str]]:
+async def get_leaderboard_data(
+    db_session_maker: AsyncDBSessionMaker
+) -> Tuple[List[List[str]], List[List[str]], List[List[str]]]:
     """
-    Fetches leaderboard data from voting results database
+    Fetches and formats all leaderboard data from the voting results database.
+
+    This function retrieves three different datasets:
+    1. Provider rankings with overall performance metrics
+    2. Head-to-head battle counts between providers
+    3. Win rate percentages for each provider against others
+
+    Args:
+        db_session_maker (AsyncDBSessionMaker): Factory function for creating async database sessions.
 
     Returns:
-        LeaderboardTableEntries: A list of LeaderboardEntry objects containing rank, provider name anchor tag, model
-            name anchor tag, win rate, and total votes.
+        Tuple containing three datasets, each as List[List[str]]:
+        - leaderboard_data: Provider rankings with performance metrics
+        - battle_counts_data: Number of comparisons between each provider pair
+        - win_rate_data: Win percentages in head-to-head matchups
     """
     # Create session
     session = await _create_db_session(db_session_maker)
     try:
-        leaderboard_data = await crud.get_leaderboard_stats(cast(AsyncSession, session))
+        leaderboard_data_raw = await get_leaderboard_stats(cast(AsyncSession, session))
+        battle_counts_data_raw = await get_head_to_head_battle_stats(cast(AsyncSession, session))
+        win_rate_data_raw = await get_head_to_head_win_rate_stats(cast(AsyncSession, session))
+
         logger.info("Fetched leaderboard data successfully.")
-        # return data formatted for the UI (adds links and styling)
-        return [
-            [
-                f'<p style="text-align: center;">{row[0]}</p>',
-                f"""
-                <a
-                    href="{constants.TTS_PROVIDER_LINKS[row[1]]["provider_link"]}"
-                    target="_blank"
-                    class="provider-link"
-                >{row[1]}</a>
-                """,
-                f"""<a
-                    href="{constants.TTS_PROVIDER_LINKS[row[1]]["model_link"]}"
-                    target="_blank"
-                    class="provider-link"
-                >{row[2]}</a>
-                """,
-                f'<p style="text-align: center;">{row[3]}</p>',
-                f'<p style="text-align: center;">{row[4]}</p>',
-            ] for row in leaderboard_data
-        ]
+
+        leaderboard_data = _format_leaderboard_data(leaderboard_data_raw)
+        battle_counts_data = _format_battle_counts_data(battle_counts_data_raw)
+        win_rate_data = _format_win_rate_data(win_rate_data_raw)
+
+        return leaderboard_data, battle_counts_data, win_rate_data
     except Exception as e:
         # Log the error with traceback
         logger.error(f"Failed to fetch leaderboard data: {e}", exc_info=True)
-        return []
+        return [[]], [[]], [[]]
     finally:
         # Always ensure the session is closed
         if session is not None:
             await session.close()
 
+def _format_leaderboard_data(leaderboard_data_raw: List[LeaderboardEntry]) -> List[List[str]]:
+    """
+    Formats raw leaderboard data for display in the UI.
+
+    Converts LeaderboardEntry objects into HTML-formatted strings with appropriate
+    styling and links for provider and model information.
+
+    Args:
+        leaderboard_data_raw (List[LeaderboardEntry]): Raw leaderboard data from the database.
+
+    Returns:
+        List[List[str]]: Formatted HTML strings for each cell in the leaderboard table.
+    """
+    return [
+        [
+            f'<p style="text-align: center;">{row[0]}</p>',
+            f"""<a href="{constants.TTS_PROVIDER_LINKS[row[1]]["provider_link"]}"
+                target="_blank"
+                class="provider-link"
+            >{row[1]}</a>
+            """,
+            f"""<a href="{constants.TTS_PROVIDER_LINKS[row[1]]["model_link"]}"
+                target="_blank"
+                class="provider-link"
+            >{row[2]}</a>
+            """,
+            f'<p style="text-align: center;">{row[3]}</p>',
+            f'<p style="text-align: center;">{row[4]}</p>',
+        ] for row in leaderboard_data_raw
+    ]
+
+
+def _format_battle_counts_data(battle_counts_data_raw: List[List[str]]) -> List[List[str]]:
+    """
+    Formats battle count data into a matrix format for the UI.
+
+    Creates a provider-by-provider matrix showing the number of direct comparisons
+    between each pair of providers. Diagonal cells show dashes as providers aren't
+    compared against themselves.
+
+    Args:
+        battle_counts_data_raw (List[List[str]]): Raw battle count data from the database,
+            where each inner list contains [comparison_type, count].
+
+    Returns:
+        List[List[str]]: HTML-formatted matrix of battle counts between providers.
+    """
+    battle_counts_dict = {item[0]: item[1] for item in battle_counts_data_raw}
+    # Create canonical comparison keys matching the expected database format
+    comparison_keys = {
+        ("Hume AI", "OpenAI"): "Hume AI - OpenAI",
+        ("Hume AI", "ElevenLabs"): "Hume AI - ElevenLabs",
+        ("OpenAI", "ElevenLabs"): "OpenAI - ElevenLabs"
+    }
+    return [
+        [
+            f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'
+        ] + [
+            f"""
+            <p style="text-align: center;">
+                {"-" if row_provider == col_provider
+                    else battle_counts_dict.get(
+                        comparison_keys.get((row_provider, col_provider)) or
+                        comparison_keys.get((col_provider, row_provider), "unknown"),
+                        "0"
+                    )
+                }
+            </p>
+            """ for col_provider in constants.TTS_PROVIDERS
+        ]
+        for row_provider in constants.TTS_PROVIDERS
+    ]
+
+
+def _format_win_rate_data(win_rate_data_raw: List[List[str]]) -> List[List[str]]:
+    """
+    Formats win rate data into a matrix format for the UI.
+
+    Creates a provider-by-provider matrix showing the percentage of times the row
+    provider won against the column provider. Diagonal cells show dashes as
+    providers aren't compared against themselves.
+
+    Args:
+        win_rate_data_raw (List[List[str]]): Raw win rate data from the database,
+            where each inner list contains [comparison_type, first_win_rate, second_win_rate].
+
+    Returns:
+        List[List[str]]: HTML-formatted matrix of win rates between providers.
+    """
+    # Create a clean lookup dictionary with provider pairs as keys
+    win_rates = {}
+    for comparison_type, first_win_rate, second_win_rate in win_rate_data_raw:
+        provider1, provider2 = comparison_type.split(" - ")
+        win_rates[(provider1, provider2)] = first_win_rate
+        win_rates[(provider2, provider1)] = second_win_rate
+
+    return [
+        [
+            f'<p style="padding-left: 8px;"><strong>{row_provider}</strong></p>'
+        ] + [
+            f"""
+            <p style="text-align: center;">
+                {"-" if row_provider == col_provider else win_rates.get((row_provider, col_provider), "0%")}
+            </p>
+            """
+            for col_provider in constants.TTS_PROVIDERS
+        ]
+        for row_provider in constants.TTS_PROVIDERS
+    ]
+
 
 def validate_env_var(var_name: str) -> str:
     """