"""Gradio web interface for CSM-1B Danish text-to-speech."""

import os
from datetime import datetime
from typing import Optional, Tuple

import gradio as gr
import torch
from transformers import AutoProcessor, CsmForConditionalGeneration


class DanishTTSInterface:
    """Load a CSM checkpoint with its processor and synthesize speech from text."""

    def __init__(self, model_path="./model"):
        """Load the processor and model from ``model_path`` onto GPU if available.

        Args:
            model_path: Directory or hub id passed to ``from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load processor and model following CSM docs pattern
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            model_path,
            device_map=self.device,
        )
        self.model.eval()

    @staticmethod
    def _resolve_sampling(temperature, do_sample):
        """Normalize a (temperature, do_sample) pair for ``generate``.

        UI sliders may deliver whole numbers as ``int`` and allow 0, but the
        Hugging Face temperature warper only accepts a strictly positive
        ``float``. A non-positive temperature is therefore treated as a
        request for greedy decoding.

        Returns:
            ``(temperature, do_sample)`` where ``temperature`` is a positive
            float, or ``None`` when it should not be forwarded at all.
        """
        if temperature is None:
            return None, do_sample
        temperature = float(temperature)
        if temperature <= 0.0:
            return None, False
        return temperature, do_sample

    def generate_speech(
        self,
        text,
        temperature=0.7,
        max_length=1024,
        speaker_id=0,
        do_sample=True,
        depth_decoder_temperature=0.7,
        depth_decoder_do_sample=True,
        top_k=50,
        top_p=0.9,
        repetition_penalty=1.0,
    ) -> Tuple[Optional[str], str]:
        """Generate speech from Danish text and save it as a WAV file.

        Args:
            text: Text to synthesize.
            temperature: Backbone sampling temperature; ``<= 0`` selects
                greedy decoding.
            max_length: Maximum sequence length passed to ``generate``.
            speaker_id: Speaker index placed in the ``[id]`` text prefix.
            do_sample: Whether the backbone samples instead of decoding
                greedily.
            depth_decoder_temperature: Depth decoder temperature; ``<= 0``
                selects greedy decoding for the depth decoder.
            depth_decoder_do_sample: Whether the depth decoder samples.
            top_k: Top-K cutoff; ``0`` or less disables it.
            top_p: Nucleus threshold; ``1.0`` or more disables it.
            repetition_penalty: Repetition penalty (only sent when sampling).

        Returns:
            ``(path, message)`` on success, ``(None, error_message)`` on any
            failure. Exceptions are reported, not raised, so the UI can show
            them in the status box.
        """
        try:
            # Format text with speaker ID following CSM docs pattern
            formatted_text = f"[{speaker_id}]{text}"

            # Prepare inputs following CSM docs exactly
            inputs = self.processor(formatted_text, add_special_tokens=True).to(self.device)

            temperature, do_sample = self._resolve_sampling(temperature, do_sample)
            depth_decoder_temperature, depth_decoder_do_sample = self._resolve_sampling(
                depth_decoder_temperature, depth_decoder_do_sample
            )

            # Prepare generation parameters
            generation_kwargs = {
                "output_audio": True,
                "max_length": int(max_length),
                "do_sample": do_sample,
                "depth_decoder_do_sample": depth_decoder_do_sample,
            }
            if temperature is not None:
                generation_kwargs["temperature"] = temperature
            if depth_decoder_temperature is not None:
                generation_kwargs["depth_decoder_temperature"] = depth_decoder_temperature

            # Add sampling parameters only if sampling is enabled
            if do_sample:
                generation_kwargs.update({
                    "top_k": int(top_k) if top_k > 0 else None,
                    "top_p": top_p if top_p < 1.0 else None,
                    # Must be a float: the HF processor rejects ints such as 2.
                    "repetition_penalty": float(repetition_penalty),
                })

            # Generate audio following CSM docs pattern
            audio = self.model.generate(**inputs, **generation_kwargs)

            # Microseconds keep names unique when two requests land in the
            # same second; otherwise one output would overwrite the other.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
            temp_path = f"output_danish_{timestamp}.wav"
            self.processor.save_audio(audio, temp_path)

            return temp_path, f"Generated Danish speech for: '{text}'"

        except Exception as e:
            # UI boundary: report the failure in the status box instead of crashing.
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg


def calculate_auto_max_length(text, multiplier=1.0):
    """Estimate a ``max_length`` for ``text``, rounded up to a multiple of 128.

    Args:
        text: Input text whose character count drives the estimate.
        multiplier: Safety factor applied to the raw estimate.

    Returns:
        An int that is a multiple of 128 and at least 256.
    """
    # Heuristic budget: 5 tokens per input character.
    text_tokens = len(text) * 5
    # Extra room for speaker tokens, special tokens, and audio generation.
    buffer = 400
    # Floor applied before rounding.
    min_length = 256

    calculated_length = max(min_length, int((text_tokens + buffer) * multiplier))

    # Round UP to the next multiple of 128 for cleaner values.
    return ((calculated_length + 127) // 128) * 128


def create_interface():
    """Create and configure the Gradio interface.

    Returns:
        The ``gr.Blocks`` app, or ``None`` when the model fails to load.
    """
    # Initialize TTS model
    try:
        tts_model = DanishTTSInterface()
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    def tts_inference(text, temperature, auto_length, auto_multiplier, max_length,
                      speaker_id, do_sample, depth_decoder_temperature,
                      depth_decoder_do_sample, top_k, top_p, repetition_penalty):
        """Gradio callback: validate input, pick max length, run synthesis."""
        # `text` can be None if the component is cleared programmatically.
        if not text or not text.strip():
            return None, "Please enter some Danish text to synthesize."

        # Determine max length based on toggle
        if auto_length:
            effective_max_length = calculate_auto_max_length(text, auto_multiplier)
            status_prefix = f"Auto max length: {effective_max_length} (multiplier: {auto_multiplier}). "
        else:
            effective_max_length = max_length
            status_prefix = f"Manual max length: {effective_max_length}. "

        audio_path, message = tts_model.generate_speech(
            text=text,
            temperature=temperature,
            max_length=effective_max_length,
            speaker_id=int(speaker_id),
            do_sample=do_sample,
            depth_decoder_temperature=depth_decoder_temperature,
            depth_decoder_do_sample=depth_decoder_do_sample,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # Prepend length info to status message (only on success)
        if audio_path:
            message = status_prefix + message

        return audio_path, message

    # Settings shared by every example row, in the order of `generation_inputs`
    # minus the leading text field.
    example_settings = [0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0]
    example_texts = [
        "Husk at gemme arbejdet, før computeren genstarter, ellers risikerer du at miste både filer og vigtige ændringer.",
        "Pakken leveres i morgen mellem 9 og 12, og du får en SMS-besked, så snart den er klar til afhentning.",
        "Vi gør opmærksom på, at toget mod Københavns Hovedbanegård er forsinket med omkring 15 minutter.",
        "Man får mest muligt ud af sin tid, og slipper for unødvendig stress, hvis man planlægger en rejse.",
    ]

    # Create Gradio interface using modern Blocks syntax
    with gr.Blocks(
        title="CSM-1B Danish Text-to-Speech"
    ) as interface:
        gr.Markdown("# CSM-1B Danish Text-to-Speech")
        gr.Markdown("Natural-sounding Danish speech synthesis with voice control. Authored by [Nicolaj Reck](https://www.linkedin.com/in/nicolaj-reck-053aa38a/)")
        gr.Markdown("")
        gr.Markdown("")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input & Voice Settings")

                text_input = gr.Textbox(
                    label="Danish Text",
                    placeholder="Indtast dansk tekst her...",
                    lines=3
                )

                speaker_id_input = gr.Radio(
                    choices=[("Male", 0), ("Female", 1)],
                    value=0,
                    label="Speaker",
                    info="Select voice gender"
                )

                temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Backbone Temperature",
                    info="Controls creativity for main model"
                )

                depth_decoder_temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Depth Decoder Temperature",
                    info="Controls creativity for depth decoder"
                )

                auto_length_input = gr.Checkbox(
                    value=True,
                    label="Auto Max Length",
                    info="Automatically adapt max length based on input text length"
                )

                auto_length_multiplier = gr.Slider(
                    minimum=0.5,
                    maximum=2.5,
                    value=1.0,
                    step=0.1,
                    label="Auto Length Multiplier",
                    info="Adjust auto-calculated max length (1.0 = base calculation)"
                )

                max_length_input = gr.Slider(
                    # 64 (not 56) keeps the default 1024 and the maximum 2048
                    # on the step-64 grid.
                    minimum=64,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Max Length (Manual)",
                    info="Manual maximum sequence length (used when auto is disabled)",
                    interactive=False  # Start disabled when auto is enabled
                )

            with gr.Column():
                gr.Markdown("### Sampling Settings")

                do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Backbone)",
                    info="Use sampling instead of greedy decoding"
                )

                depth_decoder_do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Depth Decoder)",
                    info="Use sampling for depth decoder"
                )

                top_k_input = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K",
                    info="Limit to top K tokens (0 = disabled)"
                )

                top_p_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P (Nucleus)",
                    info="Cumulative probability threshold"
                )

                repetition_penalty_input = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repetitive tokens"
                )

                generate_btn = gr.Button("Generate Speech", variant="primary", size="lg")

            with gr.Column():
                gr.Markdown("### Output")

                audio_output = gr.Audio(
                    label="Generated Speech"
                )

                status_output = gr.Textbox(
                    label="Status",
                    lines=2
                )

        # Component order must match the positional parameters of
        # `tts_inference` and the columns of the example rows.
        generation_inputs = [
            text_input,
            temperature_input,
            auto_length_input,
            auto_length_multiplier,
            max_length_input,
            speaker_id_input,
            do_sample_input,
            depth_decoder_temperature_input,
            depth_decoder_do_sample_input,
            top_k_input,
            top_p_input,
            repetition_penalty_input,
        ]

        # Toggle max length slider and multiplier based on auto mode
        def toggle_auto_controls(auto_enabled):
            """Enable the multiplier in auto mode, the manual slider otherwise."""
            return [
                gr.Slider(interactive=auto_enabled),      # multiplier
                gr.Slider(interactive=not auto_enabled)   # manual slider
            ]

        auto_length_input.change(
            fn=toggle_auto_controls,
            inputs=[auto_length_input],
            outputs=[auto_length_multiplier, max_length_input]
        )

        # Set up the generation function
        generate_btn.click(
            fn=tts_inference,
            inputs=generation_inputs,
            outputs=[audio_output, status_output]
        )

        gr.Markdown("")
        gr.Markdown("")

        # Add examples with consistent parameters
        gr.Examples(
            examples=[[sentence, *example_settings] for sentence in example_texts],
            inputs=generation_inputs
        )

    return interface


def main():
    """Main function to launch the Gradio interface"""
    print("Starting CSM-1B Danish TTS Interface...")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")

    interface = create_interface()

    if interface is None:
        print("Failed to create interface. Please check your model path and dependencies.")
        return

    # Launch the interface.
    # NOTE(review): 0.0.0.0 listens on every network interface; confirm this
    # is intended outside a container or hosted deployment.
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True,
        show_error=True
    )


if __name__ == "__main__":
    main()