|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import CsmForConditionalGeneration, AutoProcessor |
|
|
import os |
|
|
from datetime import datetime |
|
|
|
|
|
class DanishTTSInterface:
    """Wrapper around a fine-tuned CSM-1B checkpoint for Danish text-to-speech.

    Loads the processor and model once at construction time and exposes a
    single `generate_speech` entry point that returns a path to a WAV file.
    """

    def __init__(self, model_path="./model"):
        """Load processor and model from *model_path*, on GPU when available.

        Args:
            model_path: Directory containing the fine-tuned CSM checkpoint
                (processor config + model weights).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            model_path,
            device_map=self.device
        )
        # Inference only: disable dropout / training-mode behavior.
        self.model.eval()

    def generate_speech(self, text, temperature=0.7, max_length=1024, speaker_id=0,
                        do_sample=True, depth_decoder_temperature=0.7, depth_decoder_do_sample=True,
                        top_k=50, top_p=0.9, repetition_penalty=1.0):
        """Generate speech from Danish text.

        Args:
            text: Danish text to synthesize.
            temperature: Backbone sampling temperature (used only when
                ``do_sample`` is True).
            max_length: Maximum generated sequence length in tokens.
            speaker_id: Speaker tag prepended to the prompt (0 = male,
                1 = female in this fine-tune's UI).
            do_sample: Enable sampling for the backbone; greedy otherwise.
            depth_decoder_temperature: Depth-decoder sampling temperature
                (used only when ``depth_decoder_do_sample`` is True).
            depth_decoder_do_sample: Enable sampling for the depth decoder.
            top_k: Top-K cutoff; 0 disables the filter entirely.
            top_p: Nucleus-sampling threshold; 1.0 disables the filter.
            repetition_penalty: Penalty applied to repeated tokens.

        Returns:
            ``(wav_path, status_message)`` on success, or
            ``(None, error_message)`` on failure (all exceptions are caught).
        """
        try:
            # CSM expects a leading speaker tag, e.g. "[0]Hej med dig".
            formatted_text = f"[{speaker_id}]{text}"
            inputs = self.processor(formatted_text, add_special_tokens=True).to(self.device)

            generation_kwargs = {
                "output_audio": True,
                "max_length": max_length,
                "do_sample": do_sample,
                "depth_decoder_do_sample": depth_decoder_do_sample,
            }

            # Only pass sampling knobs when sampling is actually enabled:
            # transformers warns about (and ignores) temperature/top_k/top_p
            # under greedy decoding, so gating them keeps output identical
            # while silencing spurious warnings.
            if do_sample:
                generation_kwargs.update({
                    "temperature": temperature,
                    "top_k": int(top_k) if top_k > 0 else None,   # 0 => filter off
                    "top_p": top_p if top_p < 1.0 else None,      # 1.0 => filter off
                    "repetition_penalty": repetition_penalty,
                })
            if depth_decoder_do_sample:
                generation_kwargs["depth_decoder_temperature"] = depth_decoder_temperature

            audio = self.model.generate(**inputs, **generation_kwargs)

            # Timestamped filename in the working directory; Gradio serves the
            # file from this path. NOTE(review): outputs accumulate on disk —
            # consider periodic cleanup or tempfile usage.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            temp_path = f"output_danish_{timestamp}.wav"
            self.processor.save_audio(audio, temp_path)

            return temp_path, f"Generated Danish speech for: '{text}'"

        except Exception as e:
            # Broad catch is deliberate: the UI shows the error string instead
            # of crashing the server.
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg
|
|
|
|
|
def create_interface():
    """Create and configure the Gradio interface.

    Loads the TTS model eagerly; returns the Blocks app, or None when the
    model fails to load (e.g. missing ./model directory).
    """

    # Load the model once at construction time so the first request is fast;
    # a load failure aborts interface creation entirely.
    try:
        tts_model = DanishTTSInterface()
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    def calculate_auto_max_length(text, multiplier=1.0):
        """Calculate appropriate max length based on input text"""

        # Rough heuristic: ~5 generated tokens per input character.
        # NOTE(review): ratio is empirical — confirm against the model's
        # actual audio-token rate.
        text_tokens = len(text) * 5

        # Fixed headroom so short inputs still get enough room.
        buffer = 400

        # Never go below this floor, regardless of text length/multiplier.
        min_length = 256

        calculated_length = max(min_length, int((text_tokens + buffer) * multiplier))

        # Round up to the nearest multiple of 128.
        return ((calculated_length + 127) // 128) * 128

    def tts_inference(text, temperature, auto_length, auto_multiplier, max_length, speaker_id, do_sample,
                      depth_decoder_temperature, depth_decoder_do_sample, top_k, top_p, repetition_penalty):
        """Gradio interface function for TTS inference"""
        # Guard: empty/whitespace-only input produces no audio.
        if not text.strip():
            return None, "Please enter some Danish text to synthesize."

        # Either derive max length from the text or use the manual slider.
        if auto_length:
            effective_max_length = calculate_auto_max_length(text, auto_multiplier)
            status_prefix = f"Auto max length: {effective_max_length} (multiplier: {auto_multiplier}). "
        else:
            effective_max_length = max_length
            status_prefix = f"Manual max length: {effective_max_length}. "

        audio_path, message = tts_model.generate_speech(
            text=text,
            temperature=temperature,
            max_length=effective_max_length,
            speaker_id=int(speaker_id),
            do_sample=do_sample,
            depth_decoder_temperature=depth_decoder_temperature,
            depth_decoder_do_sample=depth_decoder_do_sample,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty
        )

        # Prepend the length info only on success; on failure the message is
        # already an error string.
        if audio_path:
            message = status_prefix + message

        return audio_path, message

    with gr.Blocks(
        title="CSM-1B Danish Text-to-Speech"
    ) as interface:
        gr.Markdown("# CSM-1B Danish Text-to-Speech")
        gr.Markdown("Natural-sounding Danish speech synthesis with voice control. Authored by [Nicolaj Reck](https://www.linkedin.com/in/nicolaj-reck-053aa38a/)")
        # Empty Markdown blocks act as vertical spacers.
        gr.Markdown("")
        gr.Markdown("")

        with gr.Row():
            # Column 1: text input + voice / length controls.
            with gr.Column():
                gr.Markdown("### Input & Voice Settings")
                text_input = gr.Textbox(
                    label="Danish Text",
                    placeholder="Indtast dansk tekst her...",
                    lines=3
                )
                speaker_id_input = gr.Radio(
                    choices=[("Male", 0), ("Female", 1)],
                    value=0,
                    label="Speaker",
                    info="Select voice gender"
                )

                temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Backbone Temperature",
                    info="Controls creativity for main model"
                )
                depth_decoder_temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Depth Decoder Temperature",
                    info="Controls creativity for depth decoder"
                )
                auto_length_input = gr.Checkbox(
                    value=True,
                    label="Auto Max Length",
                    info="Automatically adapt max length based on input text length"
                )
                auto_length_multiplier = gr.Slider(
                    minimum=0.5,
                    maximum=2.5,
                    value=1.0,
                    step=0.1,
                    label="Auto Length Multiplier",
                    info="Adjust auto-calculated max length (1.0 = base calculation)"
                )
                # NOTE(review): minimum=56 looks inconsistent with step=64 —
                # possibly a typo for 64 or 256; confirm before changing.
                max_length_input = gr.Slider(
                    minimum=56,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Max Length (Manual)",
                    info="Manual maximum sequence length (used when auto is disabled)",
                    # Starts disabled because Auto Max Length defaults to True.
                    interactive=False
                )

            # Column 2: sampling controls + generate button.
            with gr.Column():
                gr.Markdown("### Sampling Settings")
                do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Backbone)",
                    info="Use sampling instead of greedy decoding"
                )
                depth_decoder_do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Depth Decoder)",
                    info="Use sampling for depth decoder"
                )
                top_k_input = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K",
                    info="Limit to top K tokens (0 = disabled)"
                )
                top_p_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P (Nucleus)",
                    info="Cumulative probability threshold"
                )
                repetition_penalty_input = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repetitive tokens"
                )

                generate_btn = gr.Button("Generate Speech", variant="primary", size="lg")

            # Column 3: generated audio + status text.
            with gr.Column():
                gr.Markdown("### Output")
                audio_output = gr.Audio(
                    label="Generated Speech"
                )
                status_output = gr.Textbox(
                    label="Status",
                    lines=2
                )

        def toggle_auto_controls(auto_enabled):
            # Enable exactly one of the two length controls: the multiplier
            # when auto mode is on, the manual slider when it is off.
            return [
                gr.Slider(interactive=auto_enabled),
                gr.Slider(interactive=not auto_enabled)
            ]

        auto_length_input.change(
            fn=toggle_auto_controls,
            inputs=[auto_length_input],
            outputs=[auto_length_multiplier, max_length_input]
        )

        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input, temperature_input, auto_length_input, auto_length_multiplier, max_length_input, speaker_id_input,
                do_sample_input, depth_decoder_temperature_input, depth_decoder_do_sample_input,
                top_k_input, top_p_input, repetition_penalty_input
            ],
            outputs=[audio_output, status_output]
        )

        # Spacers before the examples table.
        gr.Markdown("")
        gr.Markdown("")

        # Clickable example rows; each row fills every tts_inference input in
        # the same order as the inputs list below.
        gr.Examples(
            examples=[
                ["Husk at gemme arbejdet, før computeren genstarter, ellers risikerer du at miste både filer og vigtige ændringer.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Pakken leveres i morgen mellem 9 og 12, og du får en SMS-besked, så snart den er klar til afhentning.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Vi gør opmærksom på, at toget mod Københavns Hovedbanegård er forsinket med omkring 15 minutter.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Man får mest muligt ud af sin tid, og slipper for unødvendig stress, hvis man planlægger en rejse.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0]
            ],
            inputs=[
                text_input, temperature_input, auto_length_input, auto_length_multiplier, max_length_input, speaker_id_input,
                do_sample_input, depth_decoder_temperature_input, depth_decoder_do_sample_input,
                top_k_input, top_p_input, repetition_penalty_input
            ]
        )

    return interface
|
|
|
|
|
def main():
    """Build the Gradio app and serve it on port 7860 (all interfaces)."""
    # Startup banner with environment info.
    banner = (
        "Starting CSM-1B Danish TTS Interface...",
        f"PyTorch version: {torch.__version__}",
        f"CUDA available: {torch.cuda.is_available()}",
    )
    for line in banner:
        print(line)

    app = create_interface()

    # Guard clause: bail out early if the model/interface failed to build.
    if app is None:
        print("Failed to create interface. Please check your model path and dependencies.")
        return

    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
        "show_error": True,
    }
    app.launch(**launch_options)
|
|
|
|
|
# Launch the server only when run as a script, not when imported as a module.
if __name__ == "__main__":
    main()