phucbienvan commited on
Commit
633ab26
·
1 Parent(s): e02c9de

fix zero gpu

Browse files
Files changed (10) hide show
  1. .gitattributes +0 -0
  2. .gitignore +0 -0
  3. README.md +0 -0
  4. app.py +58 -50
  5. generator.py +1 -0
  6. hf_requirements.txt +0 -0
  7. models.py +0 -0
  8. requirements.txt +1 -1
  9. test_model.py +0 -0
  10. watermarking.py +0 -0
.gitattributes CHANGED
File without changes
.gitignore CHANGED
File without changes
README.md CHANGED
File without changes
app.py CHANGED
@@ -11,6 +11,7 @@ from dataclasses import dataclass
11
  from generator import Segment, load_csm_1b
12
  from huggingface_hub import login
13
 
 
14
  # Disable torch compile feature to avoid triton error
15
  torch._dynamo.config.suppress_errors = True
16
 
@@ -36,7 +37,7 @@ generator = None
36
  model_loaded = False
37
 
38
  # Function to load model in ZeroGPU
39
- @spaces.GPU(duration=30)
40
  def initialize_model():
41
  global generator, model_loaded
42
  if not model_loaded:
@@ -47,7 +48,7 @@ def initialize_model():
47
  return generator
48
 
49
  # Function to get the loaded model
50
- @spaces.GPU(duration=30)
51
  def get_model():
52
  global generator, model_loaded
53
  if not model_loaded:
@@ -80,13 +81,13 @@ def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]:
80
 
81
  # Function to save audio tensor to file
82
  def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str:
83
- temp_dir = tempfile.gettempdir()
84
- output_path = os.path.join(temp_dir, f"csm1b_output_{int(time.time())}.wav")
85
  torchaudio.save(output_path, audio_tensor.unsqueeze(0), sample_rate)
86
  return output_path
87
 
88
  # Function to generate speech from text using ZeroGPU
89
- @spaces.GPU(duration=30)
90
  def generate_speech(
91
  text: str,
92
  speaker_id: int,
@@ -132,13 +133,14 @@ def generate_speech(
132
  speaker=speaker_id,
133
  context=context,
134
  max_audio_length_ms=max_duration_ms,
135
- temperature=temperature,
136
- topk=top_k
137
  )
138
 
139
  progress(0.8, "Saving audio...")
140
  # Save audio to file
141
- output_path = save_audio(audio, generator.sample_rate)
 
142
 
143
  progress(1.0, "Completed!")
144
  return output_path
@@ -156,7 +158,7 @@ def generate_speech(
156
  return f"Error generating speech: {str(e)}"
157
 
158
  # Function to generate simple speech without context
159
- @spaces.GPU(duration=30)
160
  def generate_speech_simple(
161
  text: str,
162
  speaker_id: int,
@@ -176,17 +178,23 @@ def generate_speech_simple(
176
  speaker=speaker_id,
177
  context=[], # No context
178
  max_audio_length_ms=max_duration_ms,
179
- temperature=temperature,
180
- topk=top_k
181
  )
182
 
183
  progress(0.8, "Saving audio...")
184
  # Save audio to file
185
- output_path = save_audio(audio, generator.sample_rate)
 
 
 
 
 
 
186
 
187
  progress(1.0, "Completed!")
188
  return output_path
189
- except spaces.zero.gradio.HTMLError as e:
190
  # Handle ZeroGPU quota exceeded error
191
  error_message = str(e)
192
  if "GPU quota exceeded" in error_message:
@@ -229,25 +237,25 @@ def create_demo():
229
  value=30000,
230
  step=1000
231
  )
232
- temperature = gr.Slider(
233
- label="Temperature",
234
- minimum=0.1,
235
- maximum=1.5,
236
- value=0.9,
237
- step=0.1
238
- )
239
- top_k = gr.Slider(
240
- label="Top-K",
241
- minimum=1,
242
- maximum=100,
243
- value=50,
244
- step=1
245
- )
246
 
247
  generate_btn = gr.Button("Generate Audio")
248
 
249
  with gr.Column():
250
- output_audio = gr.Audio(label="Output Audio", type="filepath")
251
 
252
  with gr.Tab("Audio Generation with Context"):
253
  gr.Markdown("This feature allows you to provide audio clips and text as context to help the model generate more appropriate speech.")
@@ -281,25 +289,25 @@ def create_demo():
281
  value=30000,
282
  step=1000
283
  )
284
- temperature_context = gr.Slider(
285
- label="Temperature",
286
- minimum=0.1,
287
- maximum=1.5,
288
- value=0.9,
289
- step=0.1
290
- )
291
- top_k_context = gr.Slider(
292
- label="Top-K",
293
- minimum=1,
294
- maximum=100,
295
- value=50,
296
- step=1
297
- )
298
 
299
  generate_context_btn = gr.Button("Generate Audio with Context")
300
 
301
  with gr.Column():
302
- output_audio_context = gr.Audio(label="Output Audio", type="filepath")
303
 
304
  # Add Hugging Face configuration tab
305
  with gr.Tab("Configuration"):
@@ -357,7 +365,7 @@ def create_demo():
357
  If you encounter a "GPU quota exceeded" error, please wait for the specified time and try again.
358
  """)
359
 
360
- @spaces.GPU(duration=10)
361
  def check_gpu():
362
  if torch.cuda.is_available():
363
  gpu_name = torch.cuda.get_device_name(0)
@@ -375,7 +383,7 @@ def create_demo():
375
  load_model_btn = gr.Button("Load Model")
376
  model_status = gr.Textbox(label="Model Status", interactive=False)
377
 
378
- @spaces.GPU(duration=10)
379
  def load_model_and_report():
380
  global model_loaded
381
  if model_loaded:
@@ -393,8 +401,8 @@ def create_demo():
393
  text_input,
394
  speaker_id,
395
  max_duration,
396
- temperature,
397
- top_k
398
  ],
399
  outputs=output_audio
400
  )
@@ -411,8 +419,8 @@ def create_demo():
411
  context_text2,
412
  context_speaker2,
413
  max_duration_context,
414
- temperature_context,
415
- top_k_context
416
  ],
417
  outputs=output_audio_context
418
  )
@@ -422,4 +430,4 @@ def create_demo():
422
  # Launch the application
423
  if __name__ == "__main__":
424
  demo = create_demo()
425
- demo.queue().launch()
 
11
  from generator import Segment, load_csm_1b
12
  from huggingface_hub import login
13
 
14
+
15
  # Disable torch compile feature to avoid triton error
16
  torch._dynamo.config.suppress_errors = True
17
 
 
37
  model_loaded = False
38
 
39
  # Function to load model in ZeroGPU
40
+ # @spaces.GPU(duration=30)
41
  def initialize_model():
42
  global generator, model_loaded
43
  if not model_loaded:
 
48
  return generator
49
 
50
  # Function to get the loaded model
51
+ # @spaces.GPU(duration=30)
52
  def get_model():
53
  global generator, model_loaded
54
  if not model_loaded:
 
81
 
82
  # Function to save audio tensor to file
83
  def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str:
84
+ # Lưu file vào thư mục hiện tại hoặc thư mục files mà Gradio mặc định sử dụng
85
+ output_path = f"csm1b_output_{int(time.time())}.wav"
86
  torchaudio.save(output_path, audio_tensor.unsqueeze(0), sample_rate)
87
  return output_path
88
 
89
  # Function to generate speech from text using ZeroGPU
90
+ # @spaces.GPU(duration=30)
91
  def generate_speech(
92
  text: str,
93
  speaker_id: int,
 
133
  speaker=speaker_id,
134
  context=context,
135
  max_audio_length_ms=max_duration_ms,
136
+ # temperature=temperature,
137
+ # topk=top_k
138
  )
139
 
140
  progress(0.8, "Saving audio...")
141
  # Save audio to file
142
+ # output_path = save_audio(audio, generator.sample_rate)
143
+ output_path = f"csm1b_output_{int(time.time())}.wav"
144
 
145
  progress(1.0, "Completed!")
146
  return output_path
 
158
  return f"Error generating speech: {str(e)}"
159
 
160
  # Function to generate simple speech without context
161
+ # @spaces.GPU(duration=30)
162
  def generate_speech_simple(
163
  text: str,
164
  speaker_id: int,
 
178
  speaker=speaker_id,
179
  context=[], # No context
180
  max_audio_length_ms=max_duration_ms,
181
+ # temperature=temperature,
182
+ # topk=top_k
183
  )
184
 
185
  progress(0.8, "Saving audio...")
186
  # Save audio to file
187
+ # output_path = save_audio(audio, generator.sample_rate)
188
+ output_path = f"csm1b_output_{int(time.time())}.wav"
189
+ torchaudio.save(output_path, audio.unsqueeze(0).cpu(), generator.sample_rate)
190
+
191
+
192
+
193
+ print(f"Audio saved to {output_path}")
194
 
195
  progress(1.0, "Completed!")
196
  return output_path
197
+ except Exception as e:
198
  # Handle ZeroGPU quota exceeded error
199
  error_message = str(e)
200
  if "GPU quota exceeded" in error_message:
 
237
  value=30000,
238
  step=1000
239
  )
240
+ # temperature = gr.Slider(
241
+ # label="Temperature",
242
+ # minimum=0.1,
243
+ # maximum=1.5,
244
+ # value=0.9,
245
+ # step=0.1
246
+ # )
247
+ # top_k = gr.Slider(
248
+ # label="Top-K",
249
+ # minimum=1,
250
+ # maximum=100,
251
+ # value=50,
252
+ # step=1
253
+ # )
254
 
255
  generate_btn = gr.Button("Generate Audio")
256
 
257
  with gr.Column():
258
+ output_audio = gr.Audio(label="Output Audio", type="filepath", autoplay=True)
259
 
260
  with gr.Tab("Audio Generation with Context"):
261
  gr.Markdown("This feature allows you to provide audio clips and text as context to help the model generate more appropriate speech.")
 
289
  value=30000,
290
  step=1000
291
  )
292
+ # temperature_context = gr.Slider(
293
+ # label="Temperature",
294
+ # minimum=0.1,
295
+ # maximum=1.5,
296
+ # value=0.9,
297
+ # step=0.1
298
+ # )
299
+ # top_k_context = gr.Slider(
300
+ # label="Top-K",
301
+ # minimum=1,
302
+ # maximum=100,
303
+ # value=50,
304
+ # step=1
305
+ # )
306
 
307
  generate_context_btn = gr.Button("Generate Audio with Context")
308
 
309
  with gr.Column():
310
+ output_audio_context = gr.Audio(label="Output Audio", type="filepath", autoplay=True)
311
 
312
  # Add Hugging Face configuration tab
313
  with gr.Tab("Configuration"):
 
365
  If you encounter a "GPU quota exceeded" error, please wait for the specified time and try again.
366
  """)
367
 
368
+ # @spaces.GPU(duration=10)
369
  def check_gpu():
370
  if torch.cuda.is_available():
371
  gpu_name = torch.cuda.get_device_name(0)
 
383
  load_model_btn = gr.Button("Load Model")
384
  model_status = gr.Textbox(label="Model Status", interactive=False)
385
 
386
+ # @spaces.GPU(duration=10)
387
  def load_model_and_report():
388
  global model_loaded
389
  if model_loaded:
 
401
  text_input,
402
  speaker_id,
403
  max_duration,
404
+ # temperature,
405
+ # top_k
406
  ],
407
  outputs=output_audio
408
  )
 
419
  context_text2,
420
  context_speaker2,
421
  max_duration_context,
422
+ # temperature_context,
423
+ # top_k_context
424
  ],
425
  outputs=output_audio_context
426
  )
 
430
  # Launch the application
431
  if __name__ == "__main__":
432
  demo = create_demo()
433
+ demo.queue().launch(share=True)
generator.py CHANGED
@@ -178,6 +178,7 @@ def load_csm_1b(device: str = "cuda") -> Generator:
178
  try:
179
  # In ZeroGPU, CUDA should not be initialized in the main process
180
  # Only move the model to GPU when called in a function with the @spaces.GPU decorator
 
181
  if 'cuda' in device and not torch.cuda.is_initialized():
182
  # Use CPU for the main process
183
  model = Model.from_pretrained("sesame/csm-1b")
 
178
  try:
179
  # In ZeroGPU, CUDA should not be initialized in the main process
180
  # Only move the model to GPU when called in a function with the @spaces.GPU decorator
181
+ print(f"Loading model on {device}")
182
  if 'cuda' in device and not torch.cuda.is_initialized():
183
  # Use CPU for the main process
184
  model = Model.from_pretrained("sesame/csm-1b")
hf_requirements.txt CHANGED
File without changes
models.py CHANGED
File without changes
requirements.txt CHANGED
@@ -3,7 +3,7 @@ torchaudio==2.4.0
3
  tokenizers==0.21.0
4
  transformers==4.49.0
5
  huggingface_hub==0.28.1
6
- moshi==0.2.2
7
  torchtune==0.4.0
8
  torchao==0.9.0
9
  silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
 
3
  tokenizers==0.21.0
4
  transformers==4.49.0
5
  huggingface_hub==0.28.1
6
+ # moshi==0.2.2
7
  torchtune==0.4.0
8
  torchao==0.9.0
9
  silentcipher @ git+https://github.com/SesameAILabs/silentcipher@master
test_model.py CHANGED
File without changes
watermarking.py CHANGED
File without changes