nicolajreck committed on
Commit
2ccd15a
·
verified ·
1 Parent(s): d997d99

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +301 -0
app.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import CsmForConditionalGeneration, AutoProcessor
4
+ import os
5
+ from datetime import datetime
6
+
7
class DanishTTSInterface:
    """Thin wrapper around a local CSM checkpoint for Danish text-to-speech.

    Loads the processor and model once at construction time and exposes
    :meth:`generate_speech`, which renders Danish text to a WAV file on disk.
    """

    def __init__(self, model_path="./model"):
        """Load the CSM model and processor from *model_path*.

        Args:
            model_path: Directory containing the fine-tuned CSM checkpoint.
        """
        # Prefer GPU when available; CSM generation is slow on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load processor and model following the CSM docs pattern.
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = CsmForConditionalGeneration.from_pretrained(
            model_path,
            device_map=self.device
        )

        # Inference only — disable dropout/batch-norm training behavior.
        self.model.eval()

    def generate_speech(self, text, temperature=0.7, max_length=1024, speaker_id=0,
                        do_sample=True, depth_decoder_temperature=0.7, depth_decoder_do_sample=True,
                        top_k=50, top_p=0.9, repetition_penalty=1.0):
        """Generate speech from Danish text.

        Args:
            text: Danish text to synthesize.
            temperature: Sampling temperature for the backbone model.
            max_length: Maximum generated sequence length (tokens).
            speaker_id: Voice id embedded inline as ``[id]`` in the prompt.
            do_sample: Sample (True) vs. greedy decoding (False) for the backbone.
            depth_decoder_temperature: Temperature for the depth decoder.
            depth_decoder_do_sample: Sampling toggle for the depth decoder.
            top_k: Top-k filter; 0 disables it (passed as None).
            top_p: Nucleus threshold; 1.0 disables it (passed as None).
            repetition_penalty: Penalty applied to repeated tokens.

        Returns:
            ``(wav_path, status_message)`` on success, ``(None, error_message)``
            on failure — errors are caught and reported, never raised.
        """
        try:
            # CSM expects the speaker encoded inline as "[id]text".
            formatted_text = f"[{speaker_id}]{text}"

            # Prepare inputs following the CSM docs exactly.
            inputs = self.processor(formatted_text, add_special_tokens=True).to(self.device)

            generation_kwargs = {
                "output_audio": True,
                "max_length": max_length,
                "temperature": temperature,
                "do_sample": do_sample,
                "depth_decoder_temperature": depth_decoder_temperature,
                "depth_decoder_do_sample": depth_decoder_do_sample,
            }

            # Sampling-only knobs: None disables the corresponding filter
            # inside transformers' generate().
            if do_sample:
                generation_kwargs.update({
                    "top_k": int(top_k) if top_k > 0 else None,
                    "top_p": top_p if top_p < 1.0 else None,
                    "repetition_penalty": repetition_penalty
                })

            # FIX: run generation under no_grad — the model is in eval mode and
            # nothing backpropagates, so tracking gradients only wastes memory.
            with torch.no_grad():
                audio = self.model.generate(**inputs, **generation_kwargs)

            # Timestamped filename so repeated requests don't clobber each other.
            # NOTE(review): generated WAVs are never deleted and will accumulate
            # in the working directory — consider a cleanup policy or tempfile.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            temp_path = f"output_danish_{timestamp}.wav"
            self.processor.save_audio(audio, temp_path)

            return temp_path, f"Generated Danish speech for: '{text}'"

        except Exception as e:
            # Surface the failure to the UI as a status string instead of raising,
            # so the Gradio handler can display it.
            error_msg = f"Error generating speech: {str(e)}"
            print(error_msg)
            return None, error_msg
64
+
65
def create_interface():
    """Create and configure the Gradio interface.

    Returns:
        The configured ``gr.Blocks`` app, or ``None`` if the TTS model
        failed to load (e.g. missing checkpoint at ``./model``).
    """

    # Initialize the TTS model eagerly so a broken checkpoint fails fast
    # instead of erroring on the first user request.
    try:
        tts_model = DanishTTSInterface()
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    def calculate_auto_max_length(text, multiplier=1.0):
        """Calculate an appropriate max token length for the given input text."""
        # Base calculation: roughly 4-6 tokens per character for Danish text,
        # plus generous extra tokens for audio generation.
        text_tokens = len(text) * 5
        # Buffer for speaker tokens, special tokens, and audio generation.
        buffer = 400
        # Floor so very short inputs still get a viable generation budget.
        min_length = 256
        # Apply the user-adjustable safety margin.
        calculated_length = max(min_length, int((text_tokens + buffer) * multiplier))
        # Round UP to the nearest multiple of 128 for cleaner values.
        return ((calculated_length + 127) // 128) * 128

    def tts_inference(text, temperature, auto_length, auto_multiplier, max_length, speaker_id, do_sample,
                      depth_decoder_temperature, depth_decoder_do_sample, top_k, top_p, repetition_penalty):
        """Gradio click handler: validate input, pick a max length, synthesize.

        Returns ``(audio_path_or_None, status_message)`` matching the two
        output components wired below.
        """
        if not text.strip():
            return None, "Please enter some Danish text to synthesize."

        # Auto mode derives max length from the text; manual mode uses the slider.
        if auto_length:
            effective_max_length = calculate_auto_max_length(text, auto_multiplier)
            status_prefix = f"Auto max length: {effective_max_length} (multiplier: {auto_multiplier}). "
        else:
            effective_max_length = max_length
            status_prefix = f"Manual max length: {effective_max_length}. "

        audio_path, message = tts_model.generate_speech(
            text=text,
            temperature=temperature,
            max_length=effective_max_length,
            speaker_id=int(speaker_id),
            do_sample=do_sample,
            depth_decoder_temperature=depth_decoder_temperature,
            depth_decoder_do_sample=depth_decoder_do_sample,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty
        )

        # Prepend the length info only on success; failures keep the raw error.
        if audio_path:
            message = status_prefix + message

        return audio_path, message

    # Build the UI with the modern Blocks API: a three-column row
    # (input/voice, sampling, output) plus event wiring and examples.
    with gr.Blocks(
        title="CSM-1B Danish Text-to-Speech"
    ) as interface:
        gr.Markdown("# CSM-1B Danish Text-to-Speech")
        gr.Markdown("Natural-sounding Danish speech synthesis with voice control. Authored by [Nicolaj Reck](https://www.linkedin.com/in/nicolaj-reck-053aa38a/)")
        # Empty markdown blocks used purely as vertical spacing.
        gr.Markdown("")
        gr.Markdown("")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input & Voice Settings")
                text_input = gr.Textbox(
                    label="Danish Text",
                    placeholder="Indtast dansk tekst her...",
                    lines=3
                )
                # Radio value is the speaker id consumed by generate_speech.
                speaker_id_input = gr.Radio(
                    choices=[("Male", 0), ("Female", 1)],
                    value=0,
                    label="Speaker",
                    info="Select voice gender"
                )

                temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Backbone Temperature",
                    info="Controls creativity for main model"
                )
                depth_decoder_temperature_input = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Depth Decoder Temperature",
                    info="Controls creativity for depth decoder"
                )
                auto_length_input = gr.Checkbox(
                    value=True,
                    label="Auto Max Length",
                    info="Automatically adapt max length based on input text length"
                )
                auto_length_multiplier = gr.Slider(
                    minimum=0.5,
                    maximum=2.5,
                    value=1.0,
                    step=0.1,
                    label="Auto Length Multiplier",
                    info="Adjust auto-calculated max length (1.0 = base calculation)"
                )
                max_length_input = gr.Slider(
                    minimum=56,
                    maximum=2048,
                    value=1024,
                    step=64,
                    label="Max Length (Manual)",
                    info="Manual maximum sequence length (used when auto is disabled)",
                    interactive=False  # Start disabled when auto is enabled
                )

            with gr.Column():
                gr.Markdown("### Sampling Settings")
                do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Backbone)",
                    info="Use sampling instead of greedy decoding"
                )
                depth_decoder_do_sample_input = gr.Checkbox(
                    value=True,
                    label="Enable Sampling (Depth Decoder)",
                    info="Use sampling for depth decoder"
                )
                top_k_input = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-K",
                    info="Limit to top K tokens (0 = disabled)"
                )
                top_p_input = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top-P (Nucleus)",
                    info="Cumulative probability threshold"
                )
                repetition_penalty_input = gr.Slider(
                    minimum=0.5,
                    maximum=2.0,
                    value=1.0,
                    step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repetitive tokens"
                )

                # NOTE(review): indentation was reconstructed from a mangled
                # source — the button is assumed to sit in this column; confirm
                # against the deployed layout.
                generate_btn = gr.Button("Generate Speech", variant="primary", size="lg")

            with gr.Column():
                gr.Markdown("### Output")
                audio_output = gr.Audio(
                    label="Generated Speech"
                )
                status_output = gr.Textbox(
                    label="Status",
                    lines=2
                )

        # Toggle which length control is editable depending on auto mode:
        # auto ON -> multiplier active, manual slider frozen (and vice versa).
        def toggle_auto_controls(auto_enabled):
            return [
                gr.Slider(interactive=auto_enabled),      # multiplier
                gr.Slider(interactive=not auto_enabled)   # manual slider
            ]

        auto_length_input.change(
            fn=toggle_auto_controls,
            inputs=[auto_length_input],
            outputs=[auto_length_multiplier, max_length_input]
        )

        # Wire the button to the inference handler; input order must match
        # tts_inference's parameter order exactly.
        generate_btn.click(
            fn=tts_inference,
            inputs=[
                text_input, temperature_input, auto_length_input, auto_length_multiplier, max_length_input, speaker_id_input,
                do_sample_input, depth_decoder_temperature_input, depth_decoder_do_sample_input,
                top_k_input, top_p_input, repetition_penalty_input
            ],
            outputs=[audio_output, status_output]
        )

        # Spacing before the examples table.
        gr.Markdown("")
        gr.Markdown("")

        # Canned Danish example sentences with consistent generation parameters.
        gr.Examples(
            examples=[
                ["Husk at gemme arbejdet, før computeren genstarter, ellers risikerer du at miste både filer og vigtige ændringer.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Pakken leveres i morgen mellem 9 og 12, og du får en SMS-besked, så snart den er klar til afhentning.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Vi gør opmærksom på, at toget mod Københavns Hovedbanegård er forsinket med omkring 15 minutter.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0],
                ["Man får mest muligt ud af sin tid, og slipper for unødvendig stress, hvis man planlægger en rejse.", 0.96, True, 1.0, 1024, 1, True, 0.7, True, 50, 0.9, 1.0]
            ],
            inputs=[
                text_input, temperature_input, auto_length_input, auto_length_multiplier, max_length_input, speaker_id_input,
                do_sample_input, depth_decoder_temperature_input, depth_decoder_do_sample_input,
                top_k_input, top_p_input, repetition_penalty_input
            ]
        )

    return interface
278
+
279
def main():
    """Entry point: report the environment, build the app, and serve it."""
    print("Starting CSM-1B Danish TTS Interface...")
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")

    app = create_interface()
    if app is None:
        # create_interface returns None when the model fails to load.
        print("Failed to create interface. Please check your model path and dependencies.")
        return

    # Bind on all interfaces at the standard Gradio/Spaces port.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": True,
        "show_error": True,
    }
    app.launch(**launch_options)


if __name__ == "__main__":
    main()