buttercrab committed on
Commit 26268b4 · unverified · 1 Parent(s): b9c8c2a
Files changed (1)
  1. app.py +30 -3
app.py CHANGED
@@ -26,6 +26,7 @@ except Exception as e:
 def run_inference(
     text_input: str,
     audio_prompt_input: Optional[Tuple[int, np.ndarray]],
+    transcription_input: Optional[str],
     max_new_tokens: int,
     cfg_scale: float,
     temperature: float,
@@ -50,6 +51,10 @@ def run_inference(
     prompt_path_for_generate = None
     if audio_prompt_input is not None:
         sr, audio_data = audio_prompt_input
+        # Enforce maximum duration of 10 seconds for the audio prompt
+        duration_sec = len(audio_data) / float(sr) if sr else 0
+        if duration_sec > 10.0:
+            raise gr.Error("Audio prompt must be 10 seconds or shorter.")
         # Check if audio_data is valid
         if (
             audio_data is None or audio_data.size == 0 or audio_data.max() == 0
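
Read on its own, the new guard rejects prompts longer than ten seconds by dividing the sample count by the sample rate that gr.Audio (type="numpy") hands the callback as an (sr, np.ndarray) tuple. A minimal sketch of the same check follows, with ValueError standing in for gr.Error so it runs without Gradio, and an illustrative 48 kHz clip; note the guard runs before the None/empty check that comes after it in the app.

import numpy as np

def check_prompt_duration(sr: int, audio_data: np.ndarray) -> float:
    # Mirrors the guard added above: sample count / sample rate.
    # Assumes audio_data is a non-None array; the app's None/empty
    # check only executes after this guard.
    duration_sec = len(audio_data) / float(sr) if sr else 0
    if duration_sec > 10.0:
        raise ValueError("Audio prompt must be 10 seconds or shorter.")
    return duration_sec

# 12 seconds of silence at an illustrative 48 kHz rate is rejected:
clip = np.zeros(48_000 * 12, dtype=np.float32)
try:
    check_prompt_duration(48_000, clip)
except ValueError as e:
    print(e)  # Audio prompt must be 10 seconds or shorter.
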
@@ -117,8 +122,15 @@ def run_inference(
 
     # Use torch.inference_mode() context manager for the generation call
     with torch.inference_mode():
+        # Concatenate transcription (if provided) to the main text
+        combined_text = (
+            text_input.strip() + "\n" + transcription_input.strip()
+            if transcription_input and not transcription_input.isspace()
+            else text_input
+        )
+
         output_audio_np = model.generate(
-            text_input,
+            combined_text,
             max_tokens=max_new_tokens,
             cfg_scale=cfg_scale,
             temperature=temperature,
@@ -242,11 +254,16 @@ with gr.Blocks(css=css) as demo:
                 lines=5,  # Increased lines
             )
             audio_prompt_input = gr.Audio(
-                label="Audio Prompt (Optional)",
+                label="Audio Prompt (≤ 10 s, Optional)",
                 show_label=True,
                 sources=["upload", "microphone"],
                 type="numpy",
             )
+            transcription_input = gr.Textbox(
+                label="Audio Prompt Transcription (Optional)",
+                placeholder="Enter transcription of your audio prompt here...",
+                lines=3,
+            )
             with gr.Accordion("Generation Parameters", open=False):
                 max_new_tokens = gr.Slider(
                     label="Max New Tokens (Audio Length)",
@@ -312,6 +329,7 @@ with gr.Blocks(css=css) as demo:
         inputs=[
             text_input,
             audio_prompt_input,
+            transcription_input,
             max_new_tokens,
             cfg_scale,
             temperature,
@@ -350,10 +368,19 @@ with gr.Blocks(css=css) as demo:
 
     if examples_list:
         gr.Examples(
-            examples=examples_list,
+            examples=[
+                [
+                    ex[0],  # text
+                    ex[1],  # audio prompt path
+                    "",  # transcription placeholder
+                    *ex[2:],
+                ]
+                for ex in examples_list
+            ],
             inputs=[
                 text_input,
                 audio_prompt_input,
+                transcription_input,
                 max_new_tokens,
                 cfg_scale,
                 temperature,
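
Because gr.Examples rows must line up positionally with the inputs list, every pre-existing example row gains an empty transcription slot at index 2. A sketch of the same remapping against a hypothetical examples_list (the row values are made up, not from the app):

examples_list = [
    ["Hello world.", "prompts/voice_a.wav", 1024, 3.0, 1.3],
]

# Same list comprehension as the diff: splice "" in after the
# audio-prompt path so each row matches the new inputs order.
padded = [
    [ex[0], ex[1], "", *ex[2:]]
    for ex in examples_list
]
print(padded[0])
# ['Hello world.', 'prompts/voice_a.wav', '', 1024, 3.0, 1.3]
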
 