# Hugging Face Spaces status: Running
import os | |
import tempfile | |
import time | |
from typing import List, Tuple | |
import gradio as gr | |
import torch | |
import torchaudio | |
# import spaces | |
from dataclasses import dataclass | |
from generator import Segment, load_csm_1b | |
from huggingface_hub import login | |
# Disable torch compile feature to avoid triton error | |
torch._dynamo.config.suppress_errors = True | |
# Check if GPU is available and configure the device | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print(f"Using device: {device}") | |
# Authenticate against the Hugging Face Hub when a token is configured.
def login_huggingface():
    """Log in to the Hugging Face Hub using the ``HF_TOKEN`` environment variable.

    When the variable is unset or empty, only a warning is printed and no
    login is attempted.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("HF_TOKEN not found in environment variables. Some models may not be accessible.")
        return
    print("Logging in to Hugging Face Hub...")
    login(token=token)
    print("Login successful!")
# Authenticate once at import time so models that require a token are reachable.
login_huggingface()

# Lazily populated model state shared by the request handlers:
# `generator` holds the loaded model, `model_loaded` flags that it is ready.
generator, model_loaded = None, False
# Function to load model in ZeroGPU
# @spaces.GPU(duration=30)
def initialize_model():
    """Load the CSM-1B generator on first call and cache it in module globals.

    Subsequent calls return the cached generator without reloading.

    Returns:
        The generator object returned by ``load_csm_1b``.
    """
    global generator, model_loaded
    if not model_loaded:
        # Use the device detected at import time instead of hard-coding
        # "cuda": on a host without CUDA the old call could only fail, now it
        # at least attempts a CPU load.
        print(f"Loading CSM-1B model on {device}...")
        generator = load_csm_1b(device=device)
        model_loaded = True
        print("Model loaded successfully!")
    return generator
# Accessor used by the request handlers to obtain the (lazily loaded) model.
# @spaces.GPU(duration=30)
def get_model():
    """Return the cached generator, loading it via ``initialize_model`` if needed."""
    if model_loaded:
        return generator
    return initialize_model()
# Inspect the PRELOAD_MODEL environment variable at startup.
def preload_model_if_needed():
    """Report whether PRELOAD_MODEL was requested; never loads the model itself.

    When PRELOAD_MODEL is "true", "1" or "yes" (case-insensitive), the
    ``model_loaded`` flag is reset to False so the first request loads the
    model. Otherwise only an informational message is printed.
    """
    global model_loaded
    requested = os.environ.get("PRELOAD_MODEL", "").lower() in ("true", "1", "yes")
    if not requested:
        print("PRELOAD_MODEL is not set. Model will be loaded on demand.")
        return
    print("PRELOAD_MODEL is set. Attempting to preload model...")
    try:
        # initialize_model() is intended to run under @spaces.GPU, which
        # cannot be invoked from import-time code, so defer the actual load
        # to the first request.
        model_loaded = False
        print("Model will be loaded on first request.")
    except Exception as e:
        print(f"Error during model preloading setup: {e}")


# Run the preload check once at import time.
preload_model_if_needed()
# Load an audio file as a mono waveform tensor.
def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]:
    """Read ``audio_path`` with torchaudio and downmix it to a single channel.

    Returns:
        ``(waveform, sample_rate)``, where ``waveform`` is the mean over the
        first (channel) axis of the loaded signal.
    """
    signal, rate = torchaudio.load(audio_path)
    # Averaging across channels turns stereo/multi-channel input into mono.
    mono = signal.mean(dim=0)
    return mono, rate
# Function to save audio tensor to file
def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str:
    """Write ``audio_tensor`` to a new, uniquely named WAV file.

    The file is created in the current working directory with a
    ``csm1b_output_`` prefix and a ``.wav`` suffix.

    Args:
        audio_tensor: Waveform shaped ``(samples,)`` or ``(channels, samples)``.
            It may live on any device; it is copied to CPU before writing.
        sample_rate: Sample rate of ``audio_tensor`` in Hz.

    Returns:
        Absolute path of the written WAV file.

    Raises:
        Whatever ``torchaudio.save`` raises; the placeholder file is removed first.
    """
    # mkstemp atomically creates a fresh file. A name derived only from
    # int(time.time()) would let two requests that finish in the same second
    # overwrite each other's output.
    fd, output_path = tempfile.mkstemp(prefix="csm1b_output_", suffix=".wav", dir=os.getcwd())
    os.close(fd)

    # torchaudio.save takes a 2-D (channels, samples) tensor; move it off the
    # GPU first, as the generation handler in this file does.
    waveform = audio_tensor.detach().cpu()
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    try:
        torchaudio.save(output_path, waveform, sample_rate)
    except Exception:
        # Don't leave an empty placeholder file behind on failure.
        os.remove(output_path)
        raise
    return output_path
# Helper: turn one (audio clip, transcript, speaker) context slot into a Segment.
def _load_context_segment(audio_path, text, speaker, target_sample_rate):
    """Build a context ``Segment`` from an uploaded clip and its transcript.

    Returns:
        A ``Segment``, or None when either the clip path or the transcript is
        missing/empty, so the caller can skip that context slot.
    """
    if not audio_path or not text:
        return None
    waveform, sample_rate = audio_to_tensor(audio_path)
    # Resample so the clip matches the generator's sample rate.
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sample_rate
        )
    return Segment(speaker=speaker, text=text, audio=waveform)


# Function to generate speech from text using ZeroGPU
# @spaces.GPU(duration=30)
def generate_speech(
    text: str,
    speaker_id: int,
    context_audio_path1: str = None,
    context_text1: str = None,
    context_speaker1: int = 0,
    context_audio_path2: str = None,
    context_text2: str = None,
    context_speaker2: int = 1,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress()
) -> str:
    """Generate speech for ``text``, conditioned on up to two context clips.

    A context slot is used only when both its audio path and its transcript
    are provided.

    Args:
        text: Text to synthesize.
        speaker_id: Speaker id for the generated utterance.
        context_audio_path1/2: Paths to context audio clips (optional).
        context_text1/2: Transcripts of the context clips (optional).
        context_speaker1/2: Speaker ids of the context clips.
        max_duration_ms: Upper bound on generated audio length, in ms.
        temperature, top_k: Accepted for interface compatibility but currently
            NOT forwarded to ``generator.generate``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the generated WAV file.

    Raises:
        gr.Error: If ``text`` is empty or generation/saving fails. Gradio shows
            the message to the user instead of treating it as an audio path.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to convert to speech.")
    try:
        generator = get_model()

        progress(0.1, "Processing context...")
        context = []
        for audio_path, context_text, context_speaker in (
            (context_audio_path1, context_text1, context_speaker1),
            (context_audio_path2, context_text2, context_speaker2),
        ):
            segment = _load_context_segment(
                audio_path, context_text, context_speaker, generator.sample_rate
            )
            if segment is not None:
                context.append(segment)

        progress(0.3, "Generating audio...")
        # NOTE(review): temperature/top_k are intentionally not passed here
        # (they were commented out); re-enable once generate() accepts them.
        audio = generator.generate(
            text=text,
            speaker=speaker_id,
            context=context,
            max_audio_length_ms=max_duration_ms,
        )

        progress(0.8, "Saving audio...")
        # Actually write the file; previously only a path was built and
        # returned, pointing at a file that never existed.
        output_path = save_audio(audio.detach().cpu(), generator.sample_rate)
        print(f"Audio saved to {output_path}")
        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        error_message = str(e)
        print(f"Error generating speech: {error_message}")
        # Handle ZeroGPU quota exceeded error with a friendlier message.
        if "GPU quota exceeded" in error_message:
            import re
            wait_time_match = re.search(r"Try again in (\d+:\d+:\d+)", error_message)
            wait_time = wait_time_match.group(1) if wait_time_match else "some time"
            raise gr.Error(f"GPU quota exceeded. Please try again in {wait_time}.") from e
        raise gr.Error(f"Error generating speech: {error_message}") from e
# Function to generate simple speech without context
# @spaces.GPU(duration=30)
def generate_speech_simple(
    text: str,
    speaker_id: int,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress()
) -> str:
    """Generate speech for ``text`` without any conversational context.

    Args:
        text: Text to synthesize.
        speaker_id: Speaker id for the generated utterance.
        max_duration_ms: Upper bound on generated audio length, in ms.
        temperature, top_k: Accepted for interface compatibility but currently
            NOT forwarded to ``generator.generate``.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the generated WAV file.

    Raises:
        gr.Error: If ``text`` is empty or generation/saving fails. Gradio shows
            the message to the user instead of treating it as an audio path.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to convert to speech.")
    try:
        generator = get_model()

        progress(0.3, "Generating audio...")
        # NOTE(review): temperature/top_k are intentionally not passed here
        # (they were commented out); re-enable once generate() accepts them.
        audio = generator.generate(
            text=text,
            speaker=speaker_id,
            context=[],  # No context
            max_audio_length_ms=max_duration_ms,
        )

        progress(0.8, "Saving audio...")
        output_path = save_audio(audio.detach().cpu(), generator.sample_rate)
        print(f"Audio saved to {output_path}")
        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        error_message = str(e)
        print(f"Error generating speech: {error_message}")
        # Handle ZeroGPU quota exceeded error with a friendlier message.
        if "GPU quota exceeded" in error_message:
            import re
            wait_time_match = re.search(r"Try again in (\d+:\d+:\d+)", error_message)
            wait_time = wait_time_match.group(1) if wait_time_match else "some time"
            raise gr.Error(f"GPU quota exceeded. Please try again in {wait_time}.") from e
        raise gr.Error(f"Error generating speech: {error_message}") from e
# Create Gradio interface
def create_demo():
    """Build the CSM-1B text-to-speech Gradio app.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="CSM-1B Text-to-Speech") as demo:
        gr.Markdown("# CSM-1B Text-to-Speech Demo")
        gr.Markdown("CSM-1B (Collaborative Speech Model) is an advanced text-to-speech model capable of generating natural-sounding speech from text.")

        # --- Text-only generation --------------------------------------
        with gr.Tab("Simple Audio Generation"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        label="Text to convert to speech",
                        placeholder="Enter the text you want to convert to speech...",
                        lines=5
                    )
                    speaker_id = gr.Number(
                        label="Speaker ID",
                        value=0,
                        precision=0,
                        minimum=0,
                        maximum=10
                    )
                    with gr.Row():
                        max_duration = gr.Slider(
                            label="Maximum Duration (ms)",
                            minimum=1000,
                            maximum=90000,
                            value=30000,
                            step=1000
                        )
                        # NOTE: temperature / top-k sliders are omitted because the
                        # generation handlers do not forward them to the model.
                    generate_btn = gr.Button("Generate Audio")
                with gr.Column():
                    output_audio = gr.Audio(label="Output Audio", type="filepath", autoplay=True)

        # --- Generation conditioned on up to two context clips ---------
        with gr.Tab("Audio Generation with Context"):
            gr.Markdown("This feature allows you to provide audio clips and text as context to help the model generate more appropriate speech.")
            with gr.Row():
                with gr.Column():
                    context_text1 = gr.Textbox(label="Context Text 1", lines=2)
                    context_audio1 = gr.Audio(label="Context Audio 1", type="filepath")
                    context_speaker1 = gr.Number(label="Speaker ID 1", value=0, precision=0)
                    context_text2 = gr.Textbox(label="Context Text 2", lines=2)
                    context_audio2 = gr.Audio(label="Context Audio 2", type="filepath")
                    context_speaker2 = gr.Number(label="Speaker ID 2", value=1, precision=0)
                    text_input_context = gr.Textbox(
                        label="Text to convert to speech",
                        placeholder="Enter the text you want to convert to speech...",
                        lines=3
                    )
                    speaker_id_context = gr.Number(
                        label="Speaker ID",
                        value=0,
                        precision=0
                    )
                    with gr.Row():
                        max_duration_context = gr.Slider(
                            label="Maximum Duration (ms)",
                            minimum=1000,
                            maximum=90000,
                            value=30000,
                            step=1000
                        )
                        # NOTE: temperature / top-k sliders are omitted because the
                        # generation handlers do not forward them to the model.
                    generate_context_btn = gr.Button("Generate Audio with Context")
                with gr.Column():
                    output_audio_context = gr.Audio(label="Output Audio", type="filepath", autoplay=True)

        # --- Hugging Face token configuration ---------------------------
        with gr.Tab("Configuration"):
            gr.Markdown("### Hugging Face Token Configuration")
            gr.Markdown("""
To use the CSM-1B model, you need access to the model on Hugging Face.
You can configure your token by:
1. Create a token at [Hugging Face Settings](https://huggingface.co/settings/tokens)
2. Set the `HF_TOKEN` environment variable with your token value
Note: In Hugging Face Spaces, you can set environment variables in the Space Settings.
""")
            hf_token_input = gr.Textbox(
                label="Hugging Face Token (Only for this session)",
                placeholder="Enter your token...",
                type="password"
            )

            def set_token(token):
                """Log in with a user-supplied token and report the outcome."""
                token = (token or "").strip()
                if not token:
                    return "Invalid token. Please enter a valid token."
                # NOTE(review): os.environ is process-wide, so the token applies to
                # every user of this server, not just this browser session.
                try:
                    login(token=token)
                except Exception as e:
                    # Report rejected tokens instead of surfacing a raw exception.
                    return f"Login failed: {e}"
                # Only persist the token once login has accepted it.
                os.environ["HF_TOKEN"] = token
                return "Token set successfully! You can now load the model."

            set_token_btn = gr.Button("Set Token")
            token_status = gr.Textbox(label="Status", interactive=False)
            set_token_btn.click(fn=set_token, inputs=hf_token_input, outputs=token_status)

        # --- GPU information and manual model loading --------------------
        with gr.Tab("GPU Information"):
            gr.Markdown("### About ZeroGPU")
            gr.Markdown("""
This application uses Hugging Face Spaces' ZeroGPU to optimize GPU usage.
ZeroGPU helps free up GPU memory when not in use, saving resources and improving performance.
When you generate audio, the GPU will be used automatically and released after completion.
Note: In the ZeroGPU environment, CUDA is not initialized in the main process, but only in functions with the @spaces.GPU decorator.
""")
            gr.Markdown("### GPU Quota Information")
            gr.Markdown("""
Hugging Face Spaces has GPU quota limitations:
- Each GPU operation has a default duration of 60 seconds
- We've reduced this to 30 seconds for audio generation and 10 seconds for GPU checks
- If you exceed your quota, you'll need to wait for it to reset (usually a few hours)
- For better performance, try generating shorter audio clips
If you encounter a "GPU quota exceeded" error, please wait for the specified time and try again.
""")

            # @spaces.GPU(duration=10)
            def check_gpu():
                """Describe the first CUDA device (name and total memory), if any."""
                if torch.cuda.is_available():
                    gpu_name = torch.cuda.get_device_name(0)
                    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                    return f"GPU: {gpu_name}\nMemory: {gpu_memory:.2f} GB"
                else:
                    return "No GPU found. The application will run on CPU."

            check_gpu_btn = gr.Button("Check GPU")
            gpu_info = gr.Textbox(label="GPU Information", interactive=False)
            check_gpu_btn.click(fn=check_gpu, inputs=None, outputs=gpu_info)

            # Add model loading button
            load_model_btn = gr.Button("Load Model")
            model_status = gr.Textbox(label="Model Status", interactive=False)

            # @spaces.GPU(duration=10)
            def load_model_and_report():
                """Load the model on demand and report the result as text."""
                global model_loaded
                if model_loaded:
                    return "Model has already been loaded!"
                try:
                    initialize_model()
                except Exception as e:
                    # Show load failures (missing token, OOM, ...) in the status box.
                    return f"Failed to load model: {e}"
                return "Model loaded successfully!"

            load_model_btn.click(fn=load_model_and_report, inputs=None, outputs=model_status)

        # Connect generation components (must happen inside the Blocks context).
        generate_btn.click(
            fn=generate_speech_simple,
            inputs=[
                text_input,
                speaker_id,
                max_duration,
            ],
            outputs=output_audio
        )
        generate_context_btn.click(
            fn=generate_speech,
            inputs=[
                text_input_context,
                speaker_id_context,
                context_audio1,
                context_text1,
                context_speaker1,
                context_audio2,
                context_text2,
                context_speaker2,
                max_duration_context,
            ],
            outputs=output_audio_context
        )
    return demo
# Build and serve the app only when run as a script (not on import).
if __name__ == "__main__":
    demo = create_demo()
    # Enable Gradio's request queue before starting the server.
    demo.queue()
    demo.launch()