import os
import re
import tempfile
import time
from dataclasses import dataclass
from typing import List, Optional, Tuple

import gradio as gr
import torch
import torchaudio
# import spaces
from huggingface_hub import login

from generator import Segment, load_csm_1b

# Disable torch compile feature to avoid triton error
torch._dynamo.config.suppress_errors = True

# Check if GPU is available and configure the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


# Login to Hugging Face Hub if token is available
def login_huggingface():
    """Log in to the Hugging Face Hub using the ``HF_TOKEN`` environment variable, if set."""
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Logging in to Hugging Face Hub...")
        login(token=hf_token)
        print("Login successful!")
    else:
        print("HF_TOKEN not found in environment variables. Some models may not be accessible.")


# Login at startup
login_huggingface()

# Global variables to track model state
generator = None
model_loaded = False


# Function to load model in ZeroGPU
# @spaces.GPU(duration=30)
def initialize_model():
    """Load the CSM-1B generator once and cache it in module globals.

    The model is loaded on the module-level ``device`` (``"cuda"`` when
    available, otherwise ``"cpu"``) instead of a hard-coded ``"cuda"``, so the
    app does not crash on CPU-only hosts.

    Returns:
        The generator object returned by ``load_csm_1b``.
    """
    global generator, model_loaded
    if not model_loaded:
        print(f"Loading CSM-1B model on {device}...")
        generator = load_csm_1b(device=device)
        model_loaded = True
        print("Model loaded successfully!")
    return generator


# Function to get the loaded model
# @spaces.GPU(duration=30)
def get_model():
    """Return the cached generator, loading it on first use."""
    if not model_loaded:
        return initialize_model()
    return generator


# Preload model if environment variable is set
def preload_model_if_needed():
    """Log how the model will be loaded, based on the ``PRELOAD_MODEL`` env var.

    NOTE: despite its name, this never loads the model. Loading is deferred to
    the first request because ``initialize_model`` was designed to run under
    ``@spaces.GPU`` (currently commented out), which cannot be called at import.
    """
    global model_loaded
    if os.environ.get("PRELOAD_MODEL", "").lower() in ("true", "1", "yes"):
        print("PRELOAD_MODEL is set. Attempting to preload model...")
        # Keep the lazy-load flag cleared; the first request performs the load.
        model_loaded = False
        print("Model will be loaded on first request.")
    else:
        print("PRELOAD_MODEL is not set. Model will be loaded on demand.")


# Call preload function at startup
preload_model_if_needed()


# Function to convert audio to tensor
def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]:
    """Load an audio file and average its channels into a single mono track.

    Returns:
        A ``(waveform, sample_rate)`` pair; the waveform has the channel
        dimension reduced away by ``mean(dim=0)``.
    """
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = waveform.mean(dim=0)  # Convert stereo to mono if needed
    return waveform, sample_rate


# Function to save audio tensor to file
def save_audio(audio_tensor: torch.Tensor, sample_rate: int) -> str:
    """Write a mono waveform to a uniquely named WAV file and return its path.

    The file is created in the current working directory, a location Gradio is
    allowed to serve. ``tempfile.mkstemp`` guarantees a unique name, so
    concurrent requests can no longer overwrite each other's output, which the
    previous one-second timestamp names allowed. The tensor is moved to the CPU
    before saving because the model may produce it on the GPU.
    """
    fd, output_path = tempfile.mkstemp(prefix="csm1b_output_", suffix=".wav", dir=os.getcwd())
    os.close(fd)  # torchaudio reopens the path itself; only the unique name is needed.
    torchaudio.save(output_path, audio_tensor.detach().cpu().unsqueeze(0), sample_rate)
    print(f"Audio saved to {output_path}")
    return output_path


def _require_text(text: Optional[str]) -> None:
    """Raise a user-visible Gradio error when *text* is empty or whitespace-only."""
    if not text or not text.strip():
        raise gr.Error("Please enter some text to convert to speech.")


def _format_generation_error(error: Exception) -> str:
    """Turn a generation exception into a user-facing message (special-casing ZeroGPU quota)."""
    error_message = str(error)
    if "GPU quota exceeded" in error_message:
        # Extract wait time from error message
        wait_time_match = re.search(r"Try again in (\d+:\d+:\d+)", error_message)
        wait_time = wait_time_match.group(1) if wait_time_match else "some time"
        return f"GPU quota exceeded. Please try again in {wait_time}."
    return f"Error generating speech: {error_message}"


def _load_context_segment(audio_path: Optional[str], text: Optional[str], speaker, target_sample_rate: int):
    """Build a context ``Segment`` resampled to the model rate, or ``None`` if clip or text is missing."""
    if not (audio_path and text):
        return None
    waveform, sample_rate = audio_to_tensor(audio_path)
    # Resample if needed
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=target_sample_rate)
    return Segment(speaker=speaker, text=text, audio=waveform)


# Function to generate speech from text using ZeroGPU
# @spaces.GPU(duration=30)
def generate_speech(
    text: str,
    speaker_id: int,
    context_audio_path1: Optional[str] = None,
    context_text1: Optional[str] = None,
    context_speaker1: int = 0,
    context_audio_path2: Optional[str] = None,
    context_text2: Optional[str] = None,
    context_speaker2: int = 1,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress(),
) -> str:
    """Generate speech for *text*, conditioned on up to two (audio, text) context clips.

    A context clip is used only when both its audio path and its text are
    provided. ``temperature`` and ``top_k`` are accepted but currently not
    forwarded to the model, matching the disabled UI controls.

    Returns:
        Path to the written WAV file.

    Raises:
        gr.Error: if *text* is empty or generation fails; Gradio shows the
            message to the user instead of treating it as an audio path.
    """
    _require_text(text)
    try:
        model = get_model()

        progress(0.1, "Processing context...")
        context = []
        for audio_path, context_text, context_speaker in (
            (context_audio_path1, context_text1, context_speaker1),
            (context_audio_path2, context_text2, context_speaker2),
        ):
            segment = _load_context_segment(audio_path, context_text, context_speaker, model.sample_rate)
            if segment is not None:
                context.append(segment)

        progress(0.3, "Generating audio...")
        audio = model.generate(
            text=text,
            speaker=speaker_id,
            context=context,
            max_audio_length_ms=max_duration_ms,
            # temperature=temperature,
            # topk=top_k
        )

        progress(0.8, "Saving audio...")
        # Previously the path was built but the audio was never written.
        output_path = save_audio(audio, model.sample_rate)

        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        raise gr.Error(_format_generation_error(e)) from e


# Function to generate simple speech without context
# @spaces.GPU(duration=30)
def generate_speech_simple(
    text: str,
    speaker_id: int,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress(),
) -> str:
    """Generate speech for *text* without any context clips.

    ``temperature`` and ``top_k`` are accepted but currently not forwarded to
    the model, matching the disabled UI controls.

    Returns:
        Path to the written WAV file.

    Raises:
        gr.Error: if *text* is empty or generation fails.
    """
    _require_text(text)
    try:
        model = get_model()

        progress(0.3, "Generating audio...")
        audio = model.generate(
            text=text,
            speaker=speaker_id,
            context=[],  # No context
            max_audio_length_ms=max_duration_ms,
            # temperature=temperature,
            # topk=top_k
        )

        progress(0.8, "Saving audio...")
        output_path = save_audio(audio, model.sample_rate)

        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        raise gr.Error(_format_generation_error(e)) from e


# Create Gradio interface
def create_demo():
    """Build the Gradio Blocks UI (simple generation, context generation, config, GPU info)."""
    with gr.Blocks(title="CSM-1B Text-to-Speech") as demo:
        gr.Markdown("# CSM-1B Text-to-Speech Demo")
        gr.Markdown("CSM-1B (Collaborative Speech Model) is an advanced text-to-speech model capable of generating natural-sounding speech from text.")

        with gr.Tab("Simple Audio Generation"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        label="Text to convert to speech",
                        placeholder="Enter the text you want to convert to speech...",
                        lines=5,
                    )
                    speaker_id = gr.Number(
                        label="Speaker ID",
                        value=0,
                        precision=0,
                        minimum=0,
                        maximum=10,
                    )
                    with gr.Row():
                        max_duration = gr.Slider(
                            label="Maximum Duration (ms)",
                            minimum=1000,
                            maximum=90000,
                            value=30000,
                            step=1000,
                        )
                        # temperature = gr.Slider(
                        #     label="Temperature",
                        #     minimum=0.1,
                        #     maximum=1.5,
                        #     value=0.9,
                        #     step=0.1
                        # )
                        # top_k = gr.Slider(
                        #     label="Top-K",
                        #     minimum=1,
                        #     maximum=100,
                        #     value=50,
                        #     step=1
                        # )
                    generate_btn = gr.Button("Generate Audio")
                with gr.Column():
                    output_audio = gr.Audio(label="Output Audio", type="filepath", autoplay=True)

        with gr.Tab("Audio Generation with Context"):
            gr.Markdown("This feature allows you to provide audio clips and text as context to help the model generate more appropriate speech.")
            with gr.Row():
                with gr.Column():
                    context_text1 = gr.Textbox(label="Context Text 1", lines=2)
                    context_audio1 = gr.Audio(label="Context Audio 1", type="filepath")
                    context_speaker1 = gr.Number(label="Speaker ID 1", value=0, precision=0)

                    context_text2 = gr.Textbox(label="Context Text 2", lines=2)
                    context_audio2 = gr.Audio(label="Context Audio 2", type="filepath")
                    context_speaker2 = gr.Number(label="Speaker ID 2", value=1, precision=0)

                    text_input_context = gr.Textbox(
                        label="Text to convert to speech",
                        placeholder="Enter the text you want to convert to speech...",
                        lines=3,
                    )
                    speaker_id_context = gr.Number(
                        label="Speaker ID",
                        value=0,
                        precision=0,
                    )
                    with gr.Row():
                        max_duration_context = gr.Slider(
                            label="Maximum Duration (ms)",
                            minimum=1000,
                            maximum=90000,
                            value=30000,
                            step=1000,
                        )
                        # temperature_context = gr.Slider(
                        #     label="Temperature",
                        #     minimum=0.1,
                        #     maximum=1.5,
                        #     value=0.9,
                        #     step=0.1
                        # )
                        # top_k_context = gr.Slider(
                        #     label="Top-K",
                        #     minimum=1,
                        #     maximum=100,
                        #     value=50,
                        #     step=1
                        # )
                    generate_context_btn = gr.Button("Generate Audio with Context")
                with gr.Column():
                    output_audio_context = gr.Audio(label="Output Audio", type="filepath", autoplay=True)

        # Add Hugging Face configuration tab
        with gr.Tab("Configuration"):
            gr.Markdown("### Hugging Face Token Configuration")
            gr.Markdown("""
            To use the CSM-1B model, you need access to the model on Hugging Face.
            You can configure your token by:

            1. Create a token at [Hugging Face Settings](https://huggingface.co/settings/tokens)
            2. Set the `HF_TOKEN` environment variable with your token value

            Note: In Hugging Face Spaces, you can set environment variables in the Space Settings.
            """)

            hf_token_input = gr.Textbox(
                label="Hugging Face Token (Only for this session)",
                placeholder="Enter your token...",
                type="password",
            )

            def set_token(token):
                """Log in with *token* and, only on success, store it in ``HF_TOKEN``."""
                token = (token or "").strip()
                if not token:
                    return "Invalid token. Please enter a valid token."
                try:
                    login(token=token)
                except Exception as e:
                    # Report login failures in the status box instead of crashing the event.
                    return f"Failed to log in with the provided token: {e}"
                os.environ["HF_TOKEN"] = token
                return "Token set successfully! You can now load the model."

            set_token_btn = gr.Button("Set Token")
            token_status = gr.Textbox(label="Status", interactive=False)
            set_token_btn.click(fn=set_token, inputs=hf_token_input, outputs=token_status)

        # Add GPU information tab
        with gr.Tab("GPU Information"):
            gr.Markdown("### About ZeroGPU")
            gr.Markdown("""
            This application uses Hugging Face Spaces' ZeroGPU to optimize GPU usage.
            ZeroGPU helps free up GPU memory when not in use, saving resources and improving performance.

            When you generate audio, the GPU will be used automatically and released after completion.

            Note: In the ZeroGPU environment, CUDA is not initialized in the main process, but only in functions with the @spaces.GPU decorator.
            """)

            gr.Markdown("### GPU Quota Information")
            gr.Markdown("""
            Hugging Face Spaces has GPU quota limitations:

            - Each GPU operation has a default duration of 60 seconds
            - We've reduced this to 30 seconds for audio generation and 10 seconds for GPU checks
            - If you exceed your quota, you'll need to wait for it to reset (usually a few hours)
            - For better performance, try generating shorter audio clips

            If you encounter a "GPU quota exceeded" error, please wait for the specified time and try again.
            """)

            # @spaces.GPU(duration=10)
            def check_gpu():
                """Describe the first CUDA device (name and total memory in GiB), or report CPU-only."""
                if torch.cuda.is_available():
                    gpu_name = torch.cuda.get_device_name(0)
                    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                    return f"GPU: {gpu_name}\nMemory: {gpu_memory:.2f} GB"
                else:
                    return "No GPU found. The application will run on CPU."

            check_gpu_btn = gr.Button("Check GPU")
            gpu_info = gr.Textbox(label="GPU Information", interactive=False)
            check_gpu_btn.click(fn=check_gpu, inputs=None, outputs=gpu_info)

            # Add model loading button
            load_model_btn = gr.Button("Load Model")
            model_status = gr.Textbox(label="Model Status", interactive=False)

            # @spaces.GPU(duration=10)
            def load_model_and_report():
                """Load the model if needed and return a status message (including failures)."""
                if model_loaded:
                    return "Model has already been loaded!"
                try:
                    initialize_model()
                except Exception as e:
                    return f"Error loading model: {e}"
                return "Model loaded successfully!"

            load_model_btn.click(fn=load_model_and_report, inputs=None, outputs=model_status)

        # Connect components
        generate_btn.click(
            fn=generate_speech_simple,
            inputs=[
                text_input,
                speaker_id,
                max_duration,
                # temperature,
                # top_k
            ],
            outputs=output_audio,
        )

        generate_context_btn.click(
            fn=generate_speech,
            inputs=[
                text_input_context,
                speaker_id_context,
                context_audio1,
                context_text1,
                context_speaker1,
                context_audio2,
                context_text2,
                context_speaker2,
                max_duration_context,
                # temperature_context,
                # top_k_context
            ],
            outputs=output_audio_context,
        )

    return demo


# Launch the application
if __name__ == "__main__":
    demo = create_demo()
    demo.queue().launch()