zach committed on
Commit
04e2d2a
·
1 Parent(s): ad1ff58

Fix types in app.py

Browse files
Files changed (2) hide show
  1. pyproject.toml +1 -0
  2. src/app.py +47 -50
pyproject.toml CHANGED
@@ -40,6 +40,7 @@ ignore = [
40
  "EM102",
41
  "FIX002",
42
  "G004",
 
43
  "PLR0913",
44
  "PLR2004",
45
  "TD002",
 
40
  "EM102",
41
  "FIX002",
42
  "G004",
43
+ "PLR0912",
44
  "PLR0913",
45
  "PLR2004",
46
  "TD002",
src/app.py CHANGED
@@ -11,7 +11,7 @@ Users can compare the outputs and vote for their favorite in an interactive UI.
11
  # Standard Library Imports
12
  import time
13
  from concurrent.futures import ThreadPoolExecutor
14
- from typing import Tuple, Union
15
 
16
  # Third-Party Library Imports
17
  import gradio as gr
@@ -19,7 +19,7 @@ import gradio as gr
19
  # Local Application Imports
20
  from src import constants
21
  from src.config import Config, logger
22
- from src.custom_types import ComparisonType, Option, OptionMap
23
  from src.database.database import DBSessionMaker
24
  from src.integrations import (
25
  AnthropicError,
@@ -50,7 +50,7 @@ class App:
50
  def _generate_text(
51
  self,
52
  character_description: str,
53
- ) -> Tuple[Union[str, gr.update], gr.update]:
54
  """
55
  Validates the character_description and generates text using Anthropic API.
56
 
@@ -59,13 +59,12 @@ class App:
59
 
60
  Returns:
61
  Tuple containing:
62
- - The generated text (as a gr.update).
63
- - An update for the generated text state.
64
 
65
  Raises:
66
  gr.Error: On validation or API errors.
67
  """
68
-
69
  try:
70
  validate_character_description_length(character_description)
71
  except ValueError as ve:
@@ -88,7 +87,7 @@ class App:
88
  character_description: str,
89
  text: str,
90
  generated_text_state: str,
91
- ) -> Tuple[gr.update, gr.update, dict, str, ComparisonType, str, str, bool, str, str]:
92
  """
93
  Synthesizes two text-to-speech outputs, updates UI state components, and returns additional TTS metadata.
94
 
@@ -98,7 +97,7 @@ class App:
98
  - Synthesize two Hume outputs (50% chance).
99
 
100
  The outputs are processed and shuffled, and the corresponding UI components for two audio players are updated.
101
- Additional metadata such as the generation IDs, comparison type, and state information are also returned.
102
 
103
  Args:
104
  character_description (str): The description of the character used for generating the voice.
@@ -108,13 +107,9 @@ class App:
108
 
109
  Returns:
110
  Tuple containing:
111
- - gr.update: Update for the first audio player (with autoplay enabled).
112
- - gr.update: Update for the second audio player.
113
- - dict: A mapping of option constants to their corresponding TTS providers.
114
- - str: The raw audio value (relative file path) for option B.
115
- - ComparisonType: The comparison type between the selected TTS providers.
116
- - str: Generation ID for option A.
117
- - str: Generation ID for option B.
118
  - bool: Flag indicating whether the text was modified.
119
  - str: The original text that was synthesized.
120
  - str: The original character description.
@@ -122,7 +117,6 @@ class App:
122
  Raises:
123
  gr.Error: If any API or unexpected errors occur during the TTS synthesis process.
124
  """
125
-
126
  if not text:
127
  logger.warning("Skipping text-to-speech due to empty text.")
128
  raise gr.Error("Please generate or enter text to synthesize.")
@@ -134,34 +128,41 @@ class App:
134
  try:
135
  if provider_b == constants.HUME_AI:
136
  num_generations = 2
137
- # If generating 2 Hume outputs, do so in a single API call
138
- (
139
- generation_id_a,
140
- audio_a,
141
- generation_id_b,
142
- audio_b,
143
- ) = text_to_speech_with_hume(character_description, text, num_generations, self.config)
144
  else:
145
  with ThreadPoolExecutor(max_workers=2) as executor:
146
  num_generations = 1
147
- # Generate a single Hume output
148
  future_audio_a = executor.submit(
149
  text_to_speech_with_hume, character_description, text, num_generations, self.config
150
  )
151
- # Generate a second TTS output from the second provider
152
  match provider_b:
153
  case constants.ELEVENLABS:
154
  future_audio_b = executor.submit(
155
  text_to_speech_with_elevenlabs, character_description, text, self.config
156
  )
157
  case _:
158
- # Additional TTS Providers can be added here
159
  raise ValueError(f"Unsupported provider: {provider_b}")
160
 
161
- generation_id_a, audio_a = future_audio_a.result()
162
- generation_id_b, audio_b = future_audio_b.result()
163
-
164
- # Shuffle options so that placement of options in the UI will always be random
 
 
 
 
 
 
 
 
165
  option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
166
  option_b = Option(provider=provider_b, audio=audio_b, generation_id=generation_id_b)
167
  options_map: OptionMap = create_shuffled_tts_options(option_a, option_b)
@@ -185,7 +186,7 @@ class App:
185
  raise gr.Error(f'There was an issue communicating with the Hume API: "{he.message}"')
186
  except Exception as e:
187
  logger.error(f"Unexpected error during TTS generation: {e}")
188
- raise gr.Error("An unexpected error ocurred. Please try again later.")
189
 
190
  def _vote(
191
  self,
@@ -195,7 +196,7 @@ class App:
195
  text_modified: bool,
196
  character_description: str,
197
  text: str,
198
- ) -> Tuple[bool, gr.update, gr.update, gr.update]:
199
  """
200
  Handles user voting.
201
 
@@ -207,16 +208,15 @@ class App:
207
  'Option A': 'Hume AI',
208
  'Option B': 'ElevenLabs',
209
  }
210
- selected_button (str): The button that was clicked.
211
 
212
  Returns:
213
  A tuple of:
214
  - A boolean indicating if the vote was accepted.
215
- - An update for the selected vote button (showing provider and trophy emoji).
216
- - An update for the unselected vote button (showing provider).
217
- - An update for enabling vote interactions.
218
  """
219
-
220
  if not option_map or vote_submitted:
221
  return gr.skip(), gr.skip(), gr.skip(), gr.skip()
222
 
@@ -224,7 +224,7 @@ class App:
224
  selected_provider = option_map[selected_option]["provider"]
225
  other_provider = option_map[other_option]["provider"]
226
 
227
- # Report voting results to be persisted to results DB
228
  submit_voting_results(
229
  option_map,
230
  selected_option,
@@ -254,7 +254,7 @@ class App:
254
  gr.update(interactive=True),
255
  )
256
 
257
- def _reset_ui(self) -> Tuple[gr.update, gr.update, gr.update, gr.update, None, bool]:
258
  """
259
  Resets UI state before generating new text.
260
 
@@ -263,17 +263,20 @@ class App:
263
  - option_a_audio_player (clear audio)
264
  - option_b_audio_player (clear audio)
265
  - vote_button_a (disable and reset button text)
266
- - vote_button_a (disable and reset button text)
267
  - option_map_state (reset option map state)
268
  - vote_submitted_state (reset submitted vote state)
269
  """
270
-
 
 
 
271
  return (
272
  gr.update(value=None),
273
  gr.update(value=None, autoplay=False),
274
  gr.update(value=constants.SELECT_OPTION_A, variant="secondary"),
275
  gr.update(value=constants.SELECT_OPTION_B, variant="secondary"),
276
- None,
277
  False,
278
  )
279
 
@@ -282,7 +285,6 @@ class App:
282
  Builds the input section including the sample character description dropdown, character
283
  description input, and generate text button.
284
  """
285
-
286
  sample_character_description_dropdown = gr.Dropdown(
287
  choices=list(constants.SAMPLE_CHARACTER_DESCRIPTIONS.keys()),
288
  label="Choose a sample character description",
@@ -308,7 +310,6 @@ class App:
308
  """
309
  Builds the output section including text input, audio players, and vote buttons.
310
  """
311
-
312
  text_input = gr.Textbox(
313
  label="Input Text",
314
  placeholder="Enter or generate text for synthesis...",
@@ -342,7 +343,6 @@ class App:
342
  Returns:
343
  gr.Blocks: The fully constructed Gradio UI layout.
344
  """
345
-
346
  custom_theme = CustomTheme()
347
  with gr.Blocks(
348
  title="Expressive TTS Arena",
@@ -384,7 +384,6 @@ class App:
384
  ) = self._build_output_section()
385
 
386
  # --- UI state components ---
387
-
388
  # Track character description used for text and voice generation
389
  character_description_state = gr.State("")
390
  # Track text used for speech synthesis
@@ -393,10 +392,8 @@ class App:
393
  generated_text_state = gr.State("")
394
  # Track whether text that was used was generated or modified/custom
395
  text_modified_state = gr.State()
396
-
397
  # Track option map (option A and option B are randomized)
398
- option_map_state = gr.State()
399
-
400
  # Track whether the user has voted for an option
401
  vote_submitted_state = gr.State(False)
402
 
@@ -506,7 +503,7 @@ class App:
506
  inputs=[],
507
  outputs=[vote_button_a, vote_button_b],
508
  ).then(
509
- fn=self.vote,
510
  inputs=[
511
  vote_submitted_state,
512
  option_map_state,
 
11
  # Standard Library Imports
12
  import time
13
  from concurrent.futures import ThreadPoolExecutor
14
+ from typing import Tuple
15
 
16
  # Third-Party Library Imports
17
  import gradio as gr
 
19
  # Local Application Imports
20
  from src import constants
21
  from src.config import Config, logger
22
+ from src.custom_types import Option, OptionMap
23
  from src.database.database import DBSessionMaker
24
  from src.integrations import (
25
  AnthropicError,
 
50
  def _generate_text(
51
  self,
52
  character_description: str,
53
+ ) -> Tuple[dict, str]:
54
  """
55
  Validates the character_description and generates text using Anthropic API.
56
 
 
59
 
60
  Returns:
61
  Tuple containing:
62
+ - The generated text update (as a dict from gr.update).
63
+ - The generated text string.
64
 
65
  Raises:
66
  gr.Error: On validation or API errors.
67
  """
 
68
  try:
69
  validate_character_description_length(character_description)
70
  except ValueError as ve:
 
87
  character_description: str,
88
  text: str,
89
  generated_text_state: str,
90
+ ) -> Tuple[dict, dict, OptionMap, bool, str, str]:
91
  """
92
  Synthesizes two text-to-speech outputs, updates UI state components, and returns additional TTS metadata.
93
 
 
97
  - Synthesize two Hume outputs (50% chance).
98
 
99
  The outputs are processed and shuffled, and the corresponding UI components for two audio players are updated.
100
+ Additional metadata such as the comparison type, generation IDs, and state information are also returned.
101
 
102
  Args:
103
  character_description (str): The description of the character used for generating the voice.
 
107
 
108
  Returns:
109
  Tuple containing:
110
+ - dict: Update for the first audio player (with autoplay enabled).
111
+ - dict: Update for the second audio player.
112
+ - OptionMap: A mapping of option constants to their corresponding TTS providers.
 
 
 
 
113
  - bool: Flag indicating whether the text was modified.
114
  - str: The original text that was synthesized.
115
  - str: The original character description.
 
117
  Raises:
118
  gr.Error: If any API or unexpected errors occur during the TTS synthesis process.
119
  """
 
120
  if not text:
121
  logger.warning("Skipping text-to-speech due to empty text.")
122
  raise gr.Error("Please generate or enter text to synthesize.")
 
128
  try:
129
  if provider_b == constants.HUME_AI:
130
  num_generations = 2
131
+ # If generating 2 Hume outputs, do so in a single API call.
132
+ result = text_to_speech_with_hume(character_description, text, num_generations, self.config)
133
+ # Enforce that 4 values are returned.
134
+ if not (isinstance(result, tuple) and len(result) == 4):
135
+ raise ValueError("Expected 4 values from Hume TTS call when generating 2 outputs")
136
+ generation_id_a, audio_a, generation_id_b, audio_b = result
 
137
  else:
138
  with ThreadPoolExecutor(max_workers=2) as executor:
139
  num_generations = 1
140
+ # Generate a single Hume output.
141
  future_audio_a = executor.submit(
142
  text_to_speech_with_hume, character_description, text, num_generations, self.config
143
  )
144
+ # Generate a second TTS output from the second provider.
145
  match provider_b:
146
  case constants.ELEVENLABS:
147
  future_audio_b = executor.submit(
148
  text_to_speech_with_elevenlabs, character_description, text, self.config
149
  )
150
  case _:
151
+ # Additional TTS Providers can be added here.
152
  raise ValueError(f"Unsupported provider: {provider_b}")
153
 
154
+ result_a = future_audio_a.result()
155
+ result_b = future_audio_b.result()
156
+ if isinstance(result_a, tuple) and len(result_a) >= 2:
157
+ generation_id_a, audio_a = result_a[0], result_a[1]
158
+ else:
159
+ raise ValueError("Unexpected return from text_to_speech_with_hume")
160
+ if isinstance(result_b, tuple) and len(result_b) >= 2:
161
+ generation_id_b, audio_b = result_b[0], result_b[1] # type: ignore
162
+ else:
163
+ raise ValueError("Unexpected return from text_to_speech_with_elevenlabs")
164
+
165
+ # Shuffle options so that placement of options in the UI will always be random.
166
  option_a = Option(provider=provider_a, audio=audio_a, generation_id=generation_id_a)
167
  option_b = Option(provider=provider_b, audio=audio_b, generation_id=generation_id_b)
168
  options_map: OptionMap = create_shuffled_tts_options(option_a, option_b)
 
186
  raise gr.Error(f'There was an issue communicating with the Hume API: "{he.message}"')
187
  except Exception as e:
188
  logger.error(f"Unexpected error during TTS generation: {e}")
189
+ raise gr.Error("An unexpected error occurred. Please try again later.")
190
 
191
  def _vote(
192
  self,
 
196
  text_modified: bool,
197
  character_description: str,
198
  text: str,
199
+ ) -> Tuple[bool, dict, dict, dict]:
200
  """
201
  Handles user voting.
202
 
 
208
  'Option A': 'Hume AI',
209
  'Option B': 'ElevenLabs',
210
  }
211
+ clicked_option_button (str): The button that was clicked.
212
 
213
  Returns:
214
  A tuple of:
215
  - A boolean indicating if the vote was accepted.
216
+ - A dict update for the selected vote button (showing provider and trophy emoji).
217
+ - A dict update for the unselected vote button (showing provider).
218
+ - A dict update for enabling vote interactions.
219
  """
 
220
  if not option_map or vote_submitted:
221
  return gr.skip(), gr.skip(), gr.skip(), gr.skip()
222
 
 
224
  selected_provider = option_map[selected_option]["provider"]
225
  other_provider = option_map[other_option]["provider"]
226
 
227
+ # Report voting results to be persisted to results DB.
228
  submit_voting_results(
229
  option_map,
230
  selected_option,
 
254
  gr.update(interactive=True),
255
  )
256
 
257
+ def _reset_ui(self) -> Tuple[dict, dict, dict, dict, OptionMap, bool]:
258
  """
259
  Resets UI state before generating new text.
260
 
 
263
  - option_a_audio_player (clear audio)
264
  - option_b_audio_player (clear audio)
265
  - vote_button_a (disable and reset button text)
266
+ - vote_button_b (disable and reset button text)
267
  - option_map_state (reset option map state)
268
  - vote_submitted_state (reset submitted vote state)
269
  """
270
+ default_option_map: OptionMap = {
271
+ "option_a": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
272
+ "option_b": {"provider": constants.HUME_AI, "generation_id": None, "audio_file_path": ""},
273
+ }
274
  return (
275
  gr.update(value=None),
276
  gr.update(value=None, autoplay=False),
277
  gr.update(value=constants.SELECT_OPTION_A, variant="secondary"),
278
  gr.update(value=constants.SELECT_OPTION_B, variant="secondary"),
279
+ default_option_map, # Reset option_map_state as a default OptionMap
280
  False,
281
  )
282
 
 
285
  Builds the input section including the sample character description dropdown, character
286
  description input, and generate text button.
287
  """
 
288
  sample_character_description_dropdown = gr.Dropdown(
289
  choices=list(constants.SAMPLE_CHARACTER_DESCRIPTIONS.keys()),
290
  label="Choose a sample character description",
 
310
  """
311
  Builds the output section including text input, audio players, and vote buttons.
312
  """
 
313
  text_input = gr.Textbox(
314
  label="Input Text",
315
  placeholder="Enter or generate text for synthesis...",
 
343
  Returns:
344
  gr.Blocks: The fully constructed Gradio UI layout.
345
  """
 
346
  custom_theme = CustomTheme()
347
  with gr.Blocks(
348
  title="Expressive TTS Arena",
 
384
  ) = self._build_output_section()
385
 
386
  # --- UI state components ---
 
387
  # Track character description used for text and voice generation
388
  character_description_state = gr.State("")
389
  # Track text used for speech synthesis
 
392
  generated_text_state = gr.State("")
393
  # Track whether text that was used was generated or modified/custom
394
  text_modified_state = gr.State()
 
395
  # Track option map (option A and option B are randomized)
396
+ option_map_state = gr.State({}) # OptionMap state as a dictionary
 
397
  # Track whether the user has voted for an option
398
  vote_submitted_state = gr.State(False)
399
 
 
503
  inputs=[],
504
  outputs=[vote_button_a, vote_button_b],
505
  ).then(
506
+ fn=self._vote,
507
  inputs=[
508
  vote_submitted_state,
509
  option_map_state,