nicolajreck's picture
Upload app.py with huggingface_hub
2ccd15a verified
raw
history blame
12.4 kB
import os
import uuid
from datetime import datetime

import gradio as gr
import torch
from transformers import CsmForConditionalGeneration, AutoProcessor
class DanishTTSInterface:
    """Wrap a CSM checkpoint and its processor to synthesize Danish speech.

    The instance keeps three attributes: ``device`` ("cuda" or "cpu"),
    ``processor`` (tokenization and audio saving) and ``model`` (generation).
    """

    def __init__(self, model_path="./model"):
        """Load the processor and model from ``model_path``.

        Args:
            model_path: Location handed to ``from_pretrained`` for both the
                processor and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load processor and model following CSM docs pattern
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            model_path,
            device_map=self.device,
        )
        self.model.eval()

    @staticmethod
    def _build_generation_kwargs(temperature, max_length, do_sample,
                                 depth_decoder_temperature, depth_decoder_do_sample,
                                 top_k, top_p, repetition_penalty):
        """Turn raw UI values into keyword arguments for ``model.generate``."""
        # UI widgets may deliver ints where floats are expected (and vice
        # versa), so normalize the numeric types up front.
        temperature = float(temperature)
        depth_decoder_temperature = float(depth_decoder_temperature)

        # The Transformers temperature warper requires a strictly positive
        # float, so a temperature of 0 is interpreted as "decode greedily"
        # instead of letting generation fail.
        backbone_sampling = bool(do_sample) and temperature > 0
        depth_sampling = bool(depth_decoder_do_sample) and depth_decoder_temperature > 0

        generation_kwargs = {
            "output_audio": True,
            "max_length": int(max_length),
            "do_sample": backbone_sampling,
            "depth_decoder_do_sample": depth_sampling,
        }
        if temperature > 0:
            generation_kwargs["temperature"] = temperature
        if depth_decoder_temperature > 0:
            generation_kwargs["depth_decoder_temperature"] = depth_decoder_temperature

        # Add sampling parameters only if sampling is enabled
        if backbone_sampling:
            generation_kwargs.update({
                "top_k": int(top_k) if top_k > 0 else None,
                "top_p": float(top_p) if top_p < 1.0 else None,
                "repetition_penalty": float(repetition_penalty),
            })
        return generation_kwargs

    def generate_speech(self, text, temperature=0.7, max_length=1024, speaker_id=0,
                        do_sample=True, depth_decoder_temperature=0.7, depth_decoder_do_sample=True,
                        top_k=50, top_p=0.9, repetition_penalty=1.0):
        """Generate speech from Danish text and save it as a WAV file.

        Args:
            text: Text to synthesize.
            temperature: Backbone sampling temperature; values <= 0 switch
                the backbone to greedy decoding.
            max_length: Maximum sequence length passed to ``generate``.
            speaker_id: Speaker tag prepended to the text as ``[id]``.
            do_sample: Whether the backbone samples instead of decoding greedily.
            depth_decoder_temperature: Depth decoder sampling temperature;
                values <= 0 switch the depth decoder to greedy decoding.
            depth_decoder_do_sample: Whether the depth decoder samples.
            top_k: Top-k cutoff; 0 (or less) disables it.
            top_p: Nucleus threshold; 1.0 (or more) disables it.
            repetition_penalty: Repetition penalty forwarded when sampling.

        Returns:
            ``(path, message)`` on success, where ``path`` is the WAV file
            written to the working directory; ``(None, error_message)`` if
            anything raised. Errors are printed, never propagated.
        """
        try:
            # Format text with speaker ID following CSM docs pattern
            formatted_text = f"[{speaker_id}]{text}"
            inputs = self.processor(formatted_text, add_special_tokens=True).to(self.device)

            generation_kwargs = self._build_generation_kwargs(
                temperature=temperature,
                max_length=max_length,
                do_sample=do_sample,
                depth_decoder_temperature=depth_decoder_temperature,
                depth_decoder_do_sample=depth_decoder_do_sample,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )

            audio = self.model.generate(**inputs, **generation_kwargs)

            # A second-resolution timestamp alone collides when two requests
            # finish within the same second; the random suffix keeps each
            # request's output file distinct.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            temp_path = f"output_danish_{timestamp}_{uuid.uuid4().hex[:8]}.wav"
            self.processor.save_audio(audio, temp_path)

            return temp_path, f"Generated Danish speech for: '{text}'"
        except Exception as e:
            # UI boundary: report the failure as a status message rather
            # than crashing the request handler.
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg
def create_interface():
    """Create and configure the Gradio interface.

    The TTS model is loaded first. If loading raises, the error is printed
    and ``None`` is returned so the caller can abort without launching.

    Returns:
        The configured ``gr.Blocks`` app, or ``None`` when the model could
        not be loaded.
    """
    # Initialize TTS model
    try:
        tts_model = DanishTTSInterface()
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    def calculate_auto_max_length(text, multiplier=1.0):
        """Estimate a max length for ``text``, as a multiple of 128.

        Heuristic: 5 tokens per input character plus a fixed 400-token
        buffer, scaled by ``multiplier``, never below 256.
        """
        text_tokens = len(text) * 5
        # Fixed allowance for speaker tokens, special tokens, and audio generation
        buffer = 400
        min_length = 256
        calculated_length = max(min_length, int((text_tokens + buffer) * multiplier))
        # Round UP to the next multiple of 128 (never down) for cleaner values
        return ((calculated_length + 127) // 128) * 128

    def tts_inference(text, temperature, auto_length, auto_multiplier, max_length, speaker_id, do_sample,
                      depth_decoder_temperature, depth_decoder_do_sample, top_k, top_p, repetition_penalty):
        """Gradio callback: validate the text, pick a max length, synthesize.

        Returns:
            ``(audio_path, status_message)``; ``audio_path`` is ``None``
            when the input is empty or generation failed.
        """
        # Guard against both a missing value and whitespace-only input.
        if not text or not text.strip():
            return None, "Please enter some Danish text to synthesize."

        # Determine max length based on toggle
        if auto_length:
            effective_max_length = calculate_auto_max_length(text, auto_multiplier)
            status_prefix = f"Auto max length: {effective_max_length} (multiplier: {auto_multiplier}). "
        else:
            effective_max_length = max_length
            status_prefix = f"Manual max length: {effective_max_length}. "

        audio_path, message = tts_model.generate_speech(
            text=text,
            temperature=temperature,
            max_length=effective_max_length,
            speaker_id=int(speaker_id),
            do_sample=do_sample,
            depth_decoder_temperature=depth_decoder_temperature,
            depth_decoder_do_sample=depth_decoder_do_sample,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )

        # Prepend length info only on success; error messages stay as-is.
        if audio_path:
            message = status_prefix + message
        return audio_path, message

    # Create Gradio interface using modern Blocks syntax
    # NOTE(review): the nesting below (button inside the sampling column,
    # output column inside the row) is assumed — confirm against the
    # intended layout.
    with gr.Blocks(
        title="CSM-1B Danish Text-to-Speech"
    ) as interface:
        gr.Markdown("# CSM-1B Danish Text-to-Speech")
        gr.Markdown("Natural-sounding Danish speech synthesis with voice control. Authored by [Nicolaj Reck](https://www.linkedin.com/in/nicolaj-reck-053aa38a/)")
        gr.Markdown("")
        gr.Markdown("")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input & Voice Settings")
                text_input = gr.Textbox(
                    label="Danish Text",
                    placeholder="Indtast dansk tekst her...",
                    lines=3
                )
                speaker_id_input = gr.Radio(
                    choices=[("Male", 0), ("Female", 1)],
                    value=0,
                    label="Speaker",
                    info="Select voice gender"
                )
                temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Backbone Temperature",
                    info="Controls creativity for main model"
                )
                depth_decoder_temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Depth Decoder Temperature",
                    info="Controls creativity for depth decoder"
                )
                auto_length_input = gr.Checkbox(
                    value=True,
                    label="Auto Max Length",
                    info="Automatically adapt max length based on input text length"
                )
                auto_length_multiplier = gr.Slider(
                    minimum=0.5,
                    maximum=2.5,
                    value=1.0,
                    step=0.1,
                    label="Auto Length Multiplier",
                    info="Adjust auto-calculated max length (1.0 = base calculation)"
                )
                max_length_input = gr.Slider(
                    # Minimum aligned with the step so that the default
                    # (1024) and the maximum (2048) lie on the step grid.
                    minimum=64,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Max Length (Manual)",
                    info="Manual maximum sequence length (used when auto is disabled)",
                    interactive=False  # Start disabled when auto is enabled
                )

            with gr.Column():
                gr.Markdown("### Sampling Settings")
                do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Backbone)",
                    info="Use sampling instead of greedy decoding"
                )
                depth_decoder_do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Depth Decoder)",
                    info="Use sampling for depth decoder"
                )
                top_k_input = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K",
                    info="Limit to top K tokens (0 = disabled)"
                )
                top_p_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P (Nucleus)",
                    info="Cumulative probability threshold"
                )
                repetition_penalty_input = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repetitive tokens"
                )
                generate_btn = gr.Button("Generate Speech", variant="primary", size="lg")

            with gr.Column():
                gr.Markdown("### Output")
                audio_output = gr.Audio(
                    label="Generated Speech"
                )
                status_output = gr.Textbox(
                    label="Status",
                    lines=2
                )

        # Toggle max length slider and multiplier based on auto mode
        def toggle_auto_controls(auto_enabled):
            """Enable the multiplier in auto mode, the manual slider otherwise."""
            return [
                gr.Slider(interactive=auto_enabled),  # multiplier
                gr.Slider(interactive=not auto_enabled)  # manual slider
            ]

        auto_length_input.change(
            fn=toggle_auto_controls,
            inputs=[auto_length_input],
            outputs=[auto_length_multiplier, max_length_input]
        )

        # Set up the generation function; the input order must match the
        # positional parameters of tts_inference.
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input, temperature_input, auto_length_input, auto_length_multiplier, max_length_input, speaker_id_input,
                do_sample_input, depth_decoder_temperature_input, depth_decoder_do_sample_input,
                top_k_input, top_p_input, repetition_penalty_input
            ],
            outputs=[audio_output, status_output]
        )

        gr.Markdown("")
        gr.Markdown("")

        # Add examples with consistent parameters (same column order as the
        # inputs list below).
        gr.Examples(
            examples=[
                ["Husk at gemme arbejdet, før computeren genstarter, ellers risikerer du at miste både filer og vigtige ændringer.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Pakken leveres i morgen mellem 9 og 12, og du får en SMS-besked, så snart den er klar til afhentning.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Vi gør opmærksom på, at toget mod Københavns Hovedbanegård er forsinket med omkring 15 minutter.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Man får mest muligt ud af sin tid, og slipper for unødvendig stress, hvis man planlægger en rejse.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0]
            ],
            inputs=[
                text_input, temperature_input, auto_length_input, auto_length_multiplier, max_length_input, speaker_id_input,
                do_sample_input, depth_decoder_temperature_input, depth_decoder_do_sample_input,
                top_k_input, top_p_input, repetition_penalty_input
            ]
        )

    return interface
def main():
    """Print environment details, build the UI, and serve it.

    A startup banner with the PyTorch version and CUDA availability is
    printed first. If the interface cannot be created, a hint is printed
    and nothing is launched; otherwise the app is launched bound to
    0.0.0.0 on port 7860 without a public share link.
    """
    print("Starting CSM-1B Danish TTS Interface...")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")

    app = create_interface()

    # Settings handed to Gradio's launch() when the app was built.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
        "show_error": True,
    }

    if app is None:
        print("Failed to create interface. Please check your model path and dependencies.")
    else:
        app.launch(**launch_options)
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()