yasserrmd committed · Commit bdbe728 · verified · 1 Parent(s): 3d52831

Update app.py

Files changed (1): app.py (+370 −119)

app.py CHANGED
@@ -22,27 +22,54 @@ logger = logging.get_logger(__name__)
 
 
 class VibeVoiceDemo:
-    def __init__(self, model_path: str, device: str = "cuda", inference_steps: int = 5):
-        self.model_path = model_path
+    def __init__(self, model_paths: dict, device: str = "cuda", inference_steps: int = 5):
+        """
+        model_paths: dict like {"VibeVoice-1.5B": "microsoft/VibeVoice-1.5B",
+                                "VibeVoice-1.1B": "microsoft/VibeVoice-1.1B"}
+        """
+        self.model_paths = model_paths
         self.device = device
         self.inference_steps = inference_steps
+
         self.is_generating = False
-        self.processor = None
-        self.model = None
+
+        # Multi-model holders
+        self.models = {}      # name -> model
+        self.processors = {}  # name -> processor
+        self.current_model_name = None
+
         self.available_voices = {}
-        self.load_model()
+
+        self.load_models()  # load all on CPU
         self.setup_voice_presets()
         self.load_example_scripts()
 
-    def load_model(self):
-        print(f"Loading processor & model from {self.model_path}")
-        self.processor = VibeVoiceProcessor.from_pretrained(self.model_path)
-        self.model = VibeVoiceForConditionalGenerationInference.from_pretrained(
-            self.model_path,
-            torch_dtype=torch.bfloat16
-        )
-        # self.model.eval()
-        # self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
+    def load_models(self):
+        print("Loading processors and models on CPU...")
+        for name, path in self.model_paths.items():
+            print(f"  - {name} from {path}")
+            proc = VibeVoiceProcessor.from_pretrained(path)
+            mdl = VibeVoiceForConditionalGenerationInference.from_pretrained(
+                path, torch_dtype=torch.bfloat16
+            )
+            # Keep on CPU initially
+            self.processors[name] = proc
+            self.models[name] = mdl
+        # choose default
+        self.current_model_name = next(iter(self.models))
+        print(f"Default model is {self.current_model_name}")
+
+    def _place_model(self, target_name: str):
+        """
+        Move the selected model to CUDA and push all others back to CPU.
+        """
+        for name, mdl in self.models.items():
+            if name == target_name:
+                self.models[name] = mdl.to(self.device)
+            else:
+                self.models[name] = mdl.to("cpu")
+        self.current_model_name = target_name
+        print(f"Model {target_name} is now on {self.device}. Others moved to CPU.")
 
     def setup_voice_presets(self):
         voices_dir = os.path.join(os.path.dirname(__file__), "voices")
@@ -69,153 +96,136 @@ class VibeVoiceDemo:
         return np.array([])
 
     @GPU(duration=60)
-    def generate_podcast(self,
-                         num_speakers: int,
-                         script: str,
-                         speaker_1: str = None,
-                         speaker_2: str = None,
-                         speaker_3: str = None,
-                         speaker_4: str = None,
-                         cfg_scale: float = 1.3):
+    def generate_podcast(self,
+                         num_speakers: int,
+                         script: str,
+                         speaker_1: str = None,
+                         speaker_2: str = None,
+                         speaker_3: str = None,
+                         speaker_4: str = None,
+                         cfg_scale: float = 1.3,
+                         model_name: str = None):
         """
         Generates a podcast as a single audio file from a script and saves it.
-        This is a non-streaming function.
+        Non-streaming.
         """
         try:
-            self.model = self.model.to(self.device)
-
-            print(f"Model successfully moved to device: {self.device.upper()}")
-
-            # Step 3: Continue with the rest of your setup.
-            self.model.eval()
-            self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)
-            # 1. Set generating state and validate inputs
+            # pick model
+            model_name = model_name or self.current_model_name
+            if model_name not in self.models:
+                raise gr.Error(f"Unknown model: {model_name}")
+
+            # place models on devices
+            self._place_model(model_name)
+            model = self.models[model_name]
+            processor = self.processors[model_name]
+
+            print(f"Using model {model_name} on {self.device}")
+
+            model.eval()
+            model.set_ddpm_inference_steps(num_steps=self.inference_steps)
+
             self.is_generating = True
 
             if not script.strip():
                 raise gr.Error("Error: Please provide a script.")
 
-            # Defend against common mistake with apostrophes
             script = script.replace("’", "'")
 
             if not 1 <= num_speakers <= 4:
                 raise gr.Error("Error: Number of speakers must be between 1 and 4.")
 
-            # 2. Collect and validate selected speakers
             selected_speakers = [speaker_1, speaker_2, speaker_3, speaker_4][:num_speakers]
             for i, speaker_name in enumerate(selected_speakers):
                 if not speaker_name or speaker_name not in self.available_voices:
                     raise gr.Error(f"Error: Please select a valid speaker for Speaker {i+1}.")
 
-            # 3. Build initial log
             log = f"🎙️ Generating podcast with {num_speakers} speakers\n"
+            log += f"🧠 Model: {model_name}\n"
             log += f"📊 Parameters: CFG Scale={cfg_scale}\n"
             log += f"🎭 Speakers: {', '.join(selected_speakers)}\n"
 
-            # 4. Load voice samples
             voice_samples = []
             for speaker_name in selected_speakers:
                 audio_path = self.available_voices[speaker_name]
-                # Assuming self.read_audio is a method in your class that returns audio data
                 audio_data = self.read_audio(audio_path)
                 if len(audio_data) == 0:
                     raise gr.Error(f"Error: Failed to load audio for {speaker_name}")
                 voice_samples.append(audio_data)
 
             log += f"✅ Loaded {len(voice_samples)} voice samples\n"
 
-            # 5. Parse and format the script
             lines = script.strip().split('\n')
             formatted_script_lines = []
             for line in lines:
                 line = line.strip()
                 if not line:
                     continue
-
-                # Check if line already has speaker format (e.g., "Speaker 1: ...")
                 if line.startswith('Speaker ') and ':' in line:
                     formatted_script_lines.append(line)
                 else:
-                    # Auto-assign speakers in rotation
                     speaker_id = len(formatted_script_lines) % num_speakers
                     formatted_script_lines.append(f"Speaker {speaker_id}: {line}")
 
             formatted_script = '\n'.join(formatted_script_lines)
             log += f"📝 Formatted script with {len(formatted_script_lines)} turns\n"
             log += "🔄 Processing with VibeVoice...\n"
 
-            # 6. Prepare inputs for the model
-            # Assuming self.processor is an object available in your class
-            inputs = self.processor(
+            inputs = processor(
                 text=[formatted_script],
                 voice_samples=[voice_samples],
                 padding=True,
                 return_tensors="pt",
                 return_attention_mask=True,
             )
 
-            # 7. Generate audio
             start_time = time.time()
-            # Assuming self.model is an object available in your class
-            outputs = self.model.generate(
+            outputs = model.generate(
                 **inputs,
                 max_new_tokens=None,
                 cfg_scale=cfg_scale,
-                tokenizer=self.processor.tokenizer,
+                tokenizer=processor.tokenizer,
                 generation_config={'do_sample': False},
-                verbose=False,  # Verbose is off for cleaner logs
+                verbose=False,
             )
             generation_time = time.time() - start_time
 
-            # 8. Extract audio output
-            # The generated audio is often in speech_outputs or a similar attribute
             if hasattr(outputs, 'speech_outputs') and outputs.speech_outputs[0] is not None:
                 audio_tensor = outputs.speech_outputs[0]
                 audio = audio_tensor.cpu().float().numpy()
             else:
                 raise gr.Error("❌ Error: No audio was generated by the model. Please try again.")
 
-            # Ensure audio is a 1D array
             if audio.ndim > 1:
                 audio = audio.squeeze()
 
-            sample_rate = 24000  # Standard sample rate for this model
+            sample_rate = 24000
 
-            # 9. Save the audio file
             output_dir = "outputs"
             os.makedirs(output_dir, exist_ok=True)
             timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
             file_path = os.path.join(output_dir, f"podcast_{timestamp}.wav")
 
-            # Write the NumPy array to a WAV file
             sf.write(file_path, audio, sample_rate)
             print(f"💾 Podcast saved to {file_path}")
 
-            # 10. Finalize log and return
             total_duration = len(audio) / sample_rate
             log += f"⏱️ Generation completed in {generation_time:.2f} seconds\n"
             log += f"🎵 Final audio duration: {total_duration:.2f} seconds\n"
             log += f"✅ Successfully saved podcast to: {file_path}\n"
 
             self.is_generating = False
             return (sample_rate, audio), log
 
         except gr.Error as e:
-            # Handle Gradio-specific errors (for user feedback)
             self.is_generating = False
             error_msg = f"❌ Input Error: {str(e)}"
             print(error_msg)
-            # In Gradio, you would typically return an update to the UI
-            # For a pure function, we re-raise or handle it as needed.
-            # This return signature matches the success case but with error info.
             return None, error_msg
 
         except Exception as e:
-            # Handle all other unexpected errors
             self.is_generating = False
             error_msg = f"❌ An unexpected error occurred: {str(e)}"
             print(error_msg)
-            import traceback
             traceback.print_exc()
             return None, error_msg
 
@@ -223,20 +233,55 @@ class VibeVoiceDemo:
 
 
     def load_example_scripts(self):
+        """Load example scripts from the text_examples directory."""
         examples_dir = os.path.join(os.path.dirname(__file__), "text_examples")
         self.example_scripts = []
+
+        # Check if text_examples directory exists
         if not os.path.exists(examples_dir):
+            print(f"Warning: text_examples directory not found at {examples_dir}")
             return
-        txt_files = sorted([f for f in os.listdir(examples_dir)
-                            if f.lower().endswith('.txt')])
+
+        # Get all .txt files in the text_examples directory
+        txt_files = sorted([f for f in os.listdir(examples_dir)
+                            if f.lower().endswith('.txt') and os.path.isfile(os.path.join(examples_dir, f))])
+
         for txt_file in txt_files:
+            file_path = os.path.join(examples_dir, txt_file)
+
+            import re
+            # Check if filename contains a time pattern like "45min", "90min", etc.
+            time_pattern = re.search(r'(\d+)min', txt_file.lower())
+            if time_pattern:
+                minutes = int(time_pattern.group(1))
+                if minutes > 15:
+                    print(f"Skipping {txt_file}: duration {minutes} minutes exceeds 15-minute limit")
+                    continue
+
             try:
-                with open(os.path.join(examples_dir, txt_file), 'r', encoding='utf-8') as f:
+                with open(file_path, 'r', encoding='utf-8') as f:
                     script_content = f.read().strip()
-                if script_content:
-                    self.example_scripts.append([1, script_content])
+
+                # Remove empty lines and lines with only whitespace
+                script_content = '\n'.join(line for line in script_content.split('\n') if line.strip())
+
+                if not script_content:
+                    continue
+
+                # Parse the script to determine number of speakers
+                num_speakers = self._get_num_speakers_from_script(script_content)
+
+                # Add to examples list as [num_speakers, script_content]
+                self.example_scripts.append([num_speakers, script_content])
+                print(f"Loaded example: {txt_file} with {num_speakers} speakers")
+
             except Exception as e:
-                print(f"Error loading {txt_file}: {e}")
+                print(f"Error loading example script {txt_file}: {e}")
+
+        if self.example_scripts:
+            print(f"Successfully loaded {len(self.example_scripts)} example scripts")
+        else:
+            print("No example scripts were loaded")
 
 
 def convert_to_16_bit_wav(data):
@@ -249,10 +294,202 @@ def convert_to_16_bit_wav(data):
 
 
 def create_demo_interface(demo_instance: VibeVoiceDemo):
-    """Create the Gradio interface (final audio only, no streaming)."""
-
-    # Custom CSS for high-end aesthetics
-    custom_css = """ ... """  # (keep your CSS unchanged)
+    custom_css = """ /* Modern light theme with gradients */
+    .gradio-container {
+        background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%);
+        font-family: 'SF Pro Display', -apple-system, BlinkMacSystemFont, sans-serif;
+    }
+
+    /* Header styling */
+    .main-header {
+        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
+        padding: 2rem;
+        border-radius: 20px;
+        margin-bottom: 2rem;
+        text-align: center;
+        box-shadow: 0 10px 40px rgba(102, 126, 234, 0.3);
+    }
+
+    .main-header h1 {
+        color: white;
+        font-size: 2.5rem;
+        font-weight: 700;
+        margin: 0;
+        text-shadow: 0 2px 4px rgba(0,0,0,0.3);
+    }
+
+    .main-header p {
+        color: rgba(255,255,255,0.9);
+        font-size: 1.1rem;
+        margin: 0.5rem 0 0 0;
+    }
+
+    /* Card styling */
+    .settings-card, .generation-card {
+        background: rgba(255, 255, 255, 0.8);
+        backdrop-filter: blur(10px);
+        border: 1px solid rgba(226, 232, 240, 0.8);
+        border-radius: 16px;
+        padding: 1.5rem;
+        margin-bottom: 1rem;
+        box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
+    }
+
+    /* Speaker selection styling */
+    .speaker-grid {
+        display: grid;
+        gap: 1rem;
+        margin-bottom: 1rem;
+    }
+
+    .speaker-item {
+        background: linear-gradient(135deg, #e2e8f0 0%, #cbd5e1 100%);
+        border: 1px solid rgba(148, 163, 184, 0.4);
+        border-radius: 12px;
+        padding: 1rem;
+        color: #374151;
+        font-weight: 500;
+    }
+
+    /* Streaming indicator */
+    .streaming-indicator {
+        display: inline-block;
+        width: 10px;
+        height: 10px;
+        background: #22c55e;
+        border-radius: 50%;
+        margin-right: 8px;
+        animation: pulse 1.5s infinite;
+    }
+
+    @keyframes pulse {
+        0% { opacity: 1; transform: scale(1); }
+        50% { opacity: 0.5; transform: scale(1.1); }
+        100% { opacity: 1; transform: scale(1); }
+    }
+
+    /* Queue status styling */
+    .queue-status {
+        background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%);
+        border: 1px solid rgba(14, 165, 233, 0.3);
+        border-radius: 8px;
+        padding: 0.75rem;
+        margin: 0.5rem 0;
+        text-align: center;
+        font-size: 0.9rem;
+        color: #0369a1;
+    }
+
+    .generate-btn {
+        background: linear-gradient(135deg, #059669 0%, #0d9488 100%);
+        border: none;
+        border-radius: 12px;
+        padding: 1rem 2rem;
+        color: white;
+        font-weight: 600;
+        font-size: 1.1rem;
+        box-shadow: 0 4px 20px rgba(5, 150, 105, 0.4);
+        transition: all 0.3s ease;
+    }
+
+    .generate-btn:hover {
+        transform: translateY(-2px);
+        box-shadow: 0 6px 25px rgba(5, 150, 105, 0.6);
+    }
+
+    .stop-btn {
+        background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
+        border: none;
+        border-radius: 12px;
+        padding: 1rem 2rem;
+        color: white;
+        font-weight: 600;
+        font-size: 1.1rem;
+        box-shadow: 0 4px 20px rgba(239, 68, 68, 0.4);
+        transition: all 0.3s ease;
+    }
+
+    .stop-btn:hover {
+        transform: translateY(-2px);
+        box-shadow: 0 6px 25px rgba(239, 68, 68, 0.6);
+    }
+
+    /* Audio player styling */
+    .audio-output {
+        background: linear-gradient(135deg, #f1f5f9 0%, #e2e8f0 100%);
+        border-radius: 16px;
+        padding: 1.5rem;
+        border: 1px solid rgba(148, 163, 184, 0.3);
+    }
+
+    .complete-audio-section {
+        margin-top: 1rem;
+        padding: 1rem;
+        background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%);
+        border: 1px solid rgba(34, 197, 94, 0.3);
+        border-radius: 12px;
+    }
+
+    /* Text areas */
+    .script-input, .log-output {
+        background: rgba(255, 255, 255, 0.9) !important;
+        border: 1px solid rgba(148, 163, 184, 0.4) !important;
+        border-radius: 12px !important;
+        color: #1e293b !important;
+        font-family: 'JetBrains Mono', monospace !important;
+    }
+
+    .script-input::placeholder {
+        color: #64748b !important;
+    }
+
+    /* Sliders */
+    .slider-container {
+        background: rgba(248, 250, 252, 0.8);
+        border: 1px solid rgba(226, 232, 240, 0.6);
+        border-radius: 8px;
+        padding: 1rem;
+        margin: 0.5rem 0;
+    }
+
+    /* Labels and text */
+    .gradio-container label {
+        color: #374151 !important;
+        font-weight: 600 !important;
+    }
+
+    .gradio-container .markdown {
+        color: #1f2937 !important;
+    }
+
+    /* Responsive design */
+    @media (max-width: 768px) {
+        .main-header h1 { font-size: 2rem; }
+        .settings-card, .generation-card { padding: 1rem; }
+    }
+
+    /* Random example button styling - more subtle professional color */
+    .random-btn {
+        background: linear-gradient(135deg, #64748b 0%, #475569 100%);
+        border: none;
+        border-radius: 12px;
+        padding: 1rem 1.5rem;
+        color: white;
+        font-weight: 600;
+        font-size: 1rem;
+        box-shadow: 0 4px 20px rgba(100, 116, 139, 0.3);
+        transition: all 0.3s ease;
+        display: inline-flex;
+        align-items: center;
+        gap: 0.5rem;
+    }
+
+    .random-btn:hover {
+        transform: translateY(-2px);
+        box-shadow: 0 6px 25px rgba(100, 116, 139, 0.4);
+        background: linear-gradient(135deg, #475569 0%, #334155 100%);
+    }
+    """
 
     with gr.Blocks(
         title="VibeVoice - AI Podcast Generator",
@@ -263,27 +500,32 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
             neutral_hue="slate",
         )
     ) as interface:
 
-        # Header
         gr.HTML("""
             <div class="main-header">
                 <h1>🎙️ Vibe Podcasting</h1>
                 <p>Generating Long-form Multi-speaker AI Podcast with VibeVoice</p>
            </div>
        """)
 
        with gr.Row():
-            # Left column - Settings
            with gr.Column(scale=1, elem_classes="settings-card"):
-                gr.Markdown("### 🎛️ **Podcast Settings**")
+                gr.Markdown("### 🎛️ Podcast Settings")
+
+                # NEW - model dropdown
+                model_dropdown = gr.Dropdown(
+                    choices=list(demo_instance.models.keys()),
+                    value=demo_instance.current_model_name,
+                    label="Model",
+                )
 
                num_speakers = gr.Slider(
                    minimum=1, maximum=4, value=2, step=1,
                    label="Number of Speakers",
                    elem_classes="slider-container"
                )
 
-                gr.Markdown("### 🎭 **Speaker Selection**")
+                gr.Markdown("### 🎭 Speaker Selection")
                available_speaker_names = list(demo_instance.available_voices.keys())
                default_speakers = ['en-Alice_woman', 'en-Carter_man', 'en-Frank_man', 'en-Maya_woman']
 
@@ -298,18 +540,17 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
                        elem_classes="speaker-item"
                    )
                    speaker_selections.append(speaker)
 
-                gr.Markdown("### ⚙️ **Advanced Settings**")
+                gr.Markdown("### ⚙️ Advanced Settings")
                with gr.Accordion("Generation Parameters", open=False):
                    cfg_scale = gr.Slider(
                        minimum=1.0, maximum=2.0, value=1.3, step=0.05,
                        label="CFG Scale (Guidance Strength)",
                        elem_classes="slider-container"
                    )
 
-            # Right column - Generation
            with gr.Column(scale=2, elem_classes="generation-card"):
-                gr.Markdown("### 📝 **Script Input**")
+                gr.Markdown("### 📝 Script Input")
                script_input = gr.Textbox(
                    label="Conversation Script",
                    placeholder="Enter your podcast script here...",
@@ -317,7 +558,7 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
                    max_lines=20,
                    elem_classes="script-input"
                )
 
                with gr.Row():
                    random_example_btn = gr.Button(
                        "🎲 Random Example", size="lg",
@@ -327,9 +568,8 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
                        "🚀 Generate Podcast", size="lg",
                        variant="primary", elem_classes="generate-btn", scale=2
                    )
 
-                # Output section
-                gr.Markdown("### 🎵 **Generated Podcast**")
+                gr.Markdown("### 🎵 Generated Podcast")
                complete_audio_output = gr.Audio(
                    label="Complete Podcast (Download)",
                    type="numpy",
@@ -338,28 +578,27 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
                    show_download_button=True,
                    visible=True
                )
 
                log_output = gr.Textbox(
                    label="Generation Log",
                    lines=8, max_lines=15,
                    interactive=False,
                    elem_classes="log-output"
                )
 
-        # === logic ===
        def update_speaker_visibility(num_speakers):
            return [gr.update(visible=(i < num_speakers)) for i in range(4)]
 
        num_speakers.change(
            fn=update_speaker_visibility,
            inputs=[num_speakers],
            outputs=speaker_selections
        )
 
-        def generate_podcast_wrapper(num_speakers, script, *speakers_and_params):
+        def generate_podcast_wrapper(model_choice, num_speakers, script, *speakers_and_params):
            try:
                speakers = speakers_and_params[:4]
-                cfg_scale = speakers_and_params[4]
+                cfg_scale_val = speakers_and_params[4]
                audio, log = demo_instance.generate_podcast(
                    num_speakers=int(num_speakers),
                    script=script,
@@ -367,7 +606,8 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
                    speaker_2=speakers[1],
                    speaker_3=speakers[2],
                    speaker_4=speakers[3],
-                    cfg_scale=cfg_scale
+                    cfg_scale=cfg_scale_val,
+                    model_name=model_choice
                )
                return audio, log
            except Exception as e:
@@ -376,7 +616,7 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
 
        generate_btn.click(
            fn=generate_podcast_wrapper,
-            inputs=[num_speakers, script_input] + speaker_selections + [cfg_scale],
+            inputs=[model_dropdown, num_speakers, script_input] + speaker_selections + [cfg_scale],
            outputs=[complete_audio_output, log_output],
            queue=True
        )
@@ -397,8 +637,8 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
            outputs=[num_speakers, script_input],
            queue=False
        )
 
-        gr.Markdown("### 📚 **Example Scripts**")
+        gr.Markdown("### 📚 Example Scripts")
        examples = getattr(demo_instance, "example_scripts", []) or [
            [1, "Speaker 1: Welcome to our AI podcast demo. This is a sample script."]
        ]
@@ -412,14 +652,24 @@ def create_demo_interface(demo_instance: VibeVoiceDemo):
 
 
 
+
 def run_demo(
-    model_path: str = "microsoft/VibeVoice-1.5B",
+    model_paths: dict = None,
     device: str = "cuda",
     inference_steps: int = 5,
     share: bool = True,
 ):
+    """
+    model_paths default includes two entries. Replace paths as needed.
+    """
+    if model_paths is None:
+        model_paths = {
+            "VibeVoice-Large": "microsoft/VibeVoice-Large",
+            "VibeVoice-1.1B": "microsoft/VibeVoice-1.1B"
+        }
+
     set_seed(42)
-    demo_instance = VibeVoiceDemo(model_path, device, inference_steps)
+    demo_instance = VibeVoiceDemo(model_paths, device, inference_steps)
     interface = create_demo_interface(demo_instance)
     interface.queue().launch(
         share=share,
@@ -429,5 +679,6 @@ def run_demo(
     )
 
 
+
 if __name__ == "__main__":
     run_demo()
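The key runtime change in this commit is the hot-swap placement in `_place_model`: every registered model stays resident in CPU RAM, and only the one selected for the current request occupies the GPU. A minimal standalone sketch of that pattern, using tiny illustrative `torch.nn` modules rather than the actual VibeVoice classes:

    import torch
    import torch.nn as nn

    # Illustrative registry: tiny stand-ins for the real VibeVoice models.
    models = {
        "small": nn.Linear(8, 8),
        "large": nn.Sequential(nn.Linear(8, 64), nn.Linear(64, 8)),
    }

    def place_model(models: dict, target_name: str, device: str) -> nn.Module:
        """Move the chosen model to `device`; park every other model on CPU."""
        for name, mdl in models.items():
            models[name] = mdl.to(device if name == target_name else "cpu")
        return models[target_name]

    device = "cuda" if torch.cuda.is_available() else "cpu"
    active = place_model(models, "large", device)
    print(next(active.parameters()).device)  # device of the selected model

The trade-off is the one the diff makes explicit: switching models costs a host-to-device copy per swap, but only one model's weights ever sit in GPU memory at a time.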
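With the new `run_demo` signature, launching the demo with a custom model registry looks roughly like this (a sketch, assuming this commit's app.py is importable as `app` and that the mapped Hugging Face Hub paths exist):

    from app import run_demo  # hypothetical import of this file

    run_demo(
        model_paths={
            # any display-name -> Hub-path mapping works
            "VibeVoice-1.5B": "microsoft/VibeVoice-1.5B",
        },
        device="cuda",
        inference_steps=5,
        share=False,
    )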