File size: 19,166 Bytes
50262ab
ce8a201
7c6ede0
575a6d7
 
025dfc0
575a6d7
7c6ede0
575a6d7
 
 
 
f4154c5
7c6ede0
 
b0b6988
7c6ede0
575a6d7
ce8a201
7c6ede0
025dfc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575a6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4154c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575a6d7
025dfc0
575a6d7
 
f4154c5
 
575a6d7
 
 
 
 
f4154c5
575a6d7
025dfc0
575a6d7
f4154c5
 
 
 
575a6d7
 
 
 
7e3fe0c
575a6d7
 
f4154c5
575a6d7
 
 
 
 
 
 
 
 
 
 
f4154c5
575a6d7
 
 
 
 
 
 
f4154c5
575a6d7
 
f4154c5
575a6d7
 
 
025dfc0
 
 
575a6d7
 
 
025dfc0
 
f4154c5
575a6d7
 
 
 
7e3fe0c
 
575a6d7
 
7e3fe0c
575a6d7
7e3fe0c
 
0e3aa4b
7e3fe0c
0e3aa4b
7e3fe0c
 
 
 
 
 
 
 
 
 
575a6d7
 
 
 
f4154c5
 
575a6d7
 
 
 
 
 
f4154c5
575a6d7
025dfc0
 
575a6d7
 
 
f4154c5
575a6d7
 
 
f4154c5
 
 
 
 
 
 
 
 
 
 
 
575a6d7
 
f4154c5
575a6d7
 
 
 
 
f4154c5
575a6d7
 
 
 
 
f4154c5
575a6d7
 
 
 
 
 
 
f4154c5
575a6d7
7e3fe0c
575a6d7
7e3fe0c
 
 
 
9d5c27d
7e3fe0c
 
 
 
 
 
 
575a6d7
 
 
 
 
 
 
 
 
7e3fe0c
575a6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c6ede0
 
575a6d7
2099990
575a6d7
 
 
 
 
2099990
117bfeb
575a6d7
 
 
7c6ede0
575a6d7
7c6ede0
575a6d7
7c6ede0
 
575a6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
025dfc0
 
 
575a6d7
 
 
 
 
 
 
 
 
 
 
 
 
 
f4154c5
 
 
575a6d7
 
 
 
 
 
7c6ede0
 
575a6d7
 
7c6ede0
575a6d7
 
025dfc0
f4154c5
575a6d7
 
7c6ede0
575a6d7
 
025dfc0
f4154c5
575a6d7
 
7c6ede0
575a6d7
 
 
 
 
7c6ede0
025dfc0
 
575a6d7
 
 
025dfc0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
from nemo.collections.asr.models import ASRModel
import torch
import gradio as gr
import spaces
import gc
import shutil
from pathlib import Path
from pydub import AudioSegment
import numpy as np
import os
import gradio.themes as gr_themes
import csv
import datetime

# Run on GPU when available; per-request code moves the model between devices.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face model id of the ASR checkpoint served by this demo.
MODEL_NAME="Quantamhash/Quantum_STT_V2.0"

# Download/load the NeMo ASR model once at import time and put it in eval
# mode; inference happens inside get_transcripts_and_raw_times.
model = ASRModel.from_pretrained(model_name=MODEL_NAME)
model.eval()


def start_session(request: gr.Request):
    """Create a per-session scratch directory under /tmp and return its path.

    Gradio injects `request` because of the gr.Request annotation; the
    session hash keys the directory so concurrent users don't collide.
    """
    session_hash = request.session_hash
    scratch_dir = Path("/tmp") / session_hash
    scratch_dir.mkdir(parents=True, exist_ok=True)

    print(f"Session with hash {session_hash} started.")
    return scratch_dir.as_posix()

def end_session(request: gr.Request):
    """Delete the session's /tmp scratch directory when the session ends."""
    session_hash = request.session_hash
    scratch_dir = Path("/tmp") / session_hash

    if scratch_dir.exists():
        shutil.rmtree(scratch_dir)

    print(f"Session with hash {session_hash} ended.")

def get_audio_segment(audio_path, start_second, end_second):
    """Clip [start_second, end_second] out of the audio file at audio_path.

    Returns a (frame_rate, samples) tuple suitable for gr.Audio, with stereo
    averaged down to mono, or None when the path is missing, the clip is
    empty, or decoding fails.
    """
    if not audio_path or not Path(audio_path).exists():
        print(f"Warning: Audio path '{audio_path}' not found or invalid for clipping.")
        return None
    try:
        # Work in integer milliseconds; clamp the start into range first.
        begin_ms = max(0, int(start_second * 1000))
        finish_ms = int(end_second * 1000)
        if finish_ms <= begin_ms:
            print(f"Warning: End time ({end_second}s) is not after start time ({start_second}s). Adjusting end time.")
            # Force a minimal 100 ms clip so playback still works.
            finish_ms = begin_ms + 100

        source = AudioSegment.from_file(audio_path)
        clip = source[begin_ms:finish_ms]

        samples = np.array(clip.get_array_of_samples())
        if clip.channels == 2:
            # Interleaved stereo -> average the two channels to mono.
            samples = samples.reshape((-1, 2)).mean(axis=1).astype(samples.dtype)

        frame_rate = clip.frame_rate
        if frame_rate <= 0:
            print(f"Warning: Invalid frame rate ({frame_rate}) detected for clipped audio.")
            # Fall back to the source file's rate.
            frame_rate = source.frame_rate

        if samples.size == 0:
            print(f"Warning: Clipped audio resulted in empty samples array ({start_second}s to {end_second}s).")
            return None

        return (frame_rate, samples)
    except FileNotFoundError:
        print(f"Error: Audio file not found at path: {audio_path}")
        return None
    except Exception as e:
        print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
        return None

def format_srt_time(seconds: float) -> str:
    """Render a (non-negative, clamped) second count as SRT 'HH:MM:SS,mmm'."""
    clamped = max(0.0, seconds)
    delta = datetime.timedelta(seconds=clamped)

    # timedelta splits the value for us: whole seconds + sub-second microseconds.
    whole_seconds = int(delta.total_seconds())
    total_minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    millis = delta.microseconds // 1000

    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

def generate_srt_content(segment_timestamps: list) -> str:
    """Render segment timestamps as an SRT document.

    Each cue is: 1-based index, 'start --> end' time range, the segment
    text, then a blank separator line.
    """
    lines = []
    for cue_number, ts in enumerate(segment_timestamps, start=1):
        lines.extend([
            str(cue_number),
            f"{format_srt_time(ts['start'])} --> {format_srt_time(ts['end'])}",
            ts['segment'],
            "",
        ])
    return "\n".join(lines)

@spaces.GPU
def get_transcripts_and_raw_times(audio_path, session_dir):
    """Transcribe `audio_path` with the global NeMo model and build UI updates.

    Pre-processes the audio to 16 kHz mono when needed, applies long-audio
    attention/chunking settings for clips over 8 minutes, and writes CSV and
    SRT transcripts into `session_dir`.

    Returns a 5-tuple:
        vis_data          -- [[start, end, text], ...] rows for the results table
        raw_times_data    -- [[start, end], ...] floats used for segment playback
        audio_path        -- the original input path (kept in state for playback)
        csv_button_update -- gr.DownloadButton update (CSV visibility/value)
        srt_button_update -- gr.DownloadButton update (SRT visibility/value)
    """
    if not audio_path:
        gr.Error("No audio file path provided for transcription.", duration=None)
        # Return an update to hide the buttons
        return [], [], None, gr.DownloadButton(label="Download Transcript (CSV)", visible=False), gr.DownloadButton(label="Download Transcript (SRT)", visible=False)

    # Placeholder rows shown if processing fails before transcription completes.
    vis_data = [["N/A", "N/A", "Processing failed"]]
    raw_times_data = [[0.0, 0.0]]
    processed_audio_path = None
    csv_file_path = None
    srt_file_path = None
    original_path_name = Path(audio_path).name
    audio_name = Path(audio_path).stem

    # Initialize button states (hidden until a transcript file actually exists).
    csv_button_update = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
    srt_button_update = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)

    try:
        try:
            gr.Info(f"Loading audio: {original_path_name}", duration=2)
            audio = AudioSegment.from_file(audio_path)
            duration_sec = audio.duration_seconds
        except Exception as load_e:
            gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
            return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update

        resampled = False
        mono = False

        # The ASR model expects 16 kHz mono input.
        target_sr = 16000
        if audio.frame_rate != target_sr:
            try:
                audio = audio.set_frame_rate(target_sr)
                resampled = True
            except Exception as resample_e:
                gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
                return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update

        if audio.channels == 2:
            try:
                audio = audio.set_channels(1)
                mono = True
            except Exception as mono_e:
                gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
                return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
        elif audio.channels > 2:
            gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
            return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update

        # If we altered the audio, export a temporary WAV and transcribe that;
        # otherwise transcribe the original file directly.
        if resampled or mono:
            try:
                processed_audio_path = Path(session_dir, f"{audio_name}_resampled.wav")
                audio.export(processed_audio_path, format="wav")
                transcribe_path = processed_audio_path.as_posix()
                info_path_name = f"{original_path_name} (processed)"
            except Exception as export_e:
                gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
                if processed_audio_path and os.path.exists(processed_audio_path):
                    os.remove(processed_audio_path)
                return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update
        else:
            transcribe_path = audio_path
            info_path_name = original_path_name

        # Flag to track if long audio settings were applied
        long_audio_settings_applied = False
        try:
            model.to(device)
            model.to(torch.float32)
            gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)

            # Check duration and apply specific settings for long audio
            if duration_sec > 480 : # 8 minutes
                try:
                    gr.Info("Audio longer than 8 minutes. Applying optimized settings for long transcription.", duration=3)
                    print("Applying long audio settings: Local Attention and Chunking.")
                    model.change_attention_model("rel_pos_local_attn", [256,256])
                    model.change_subsampling_conv_chunking_factor(1)  # 1 = auto select
                    long_audio_settings_applied = True
                except Exception as setting_e:
                    gr.Warning(f"Could not apply long audio settings: {setting_e}", duration=5)
                    print(f"Warning: Failed to apply long audio settings: {setting_e}")
                    # Proceed without long audio settings if applying them failed

            # bfloat16 halves memory during inference.
            model.to(torch.bfloat16)
            output = model.transcribe([transcribe_path], timestamps=True)

            if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
                gr.Error("Transcription failed or produced unexpected output format.", duration=None)
                # Return an update to hide the buttons
                return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update

            segment_timestamps = output[0].timestamp['segment']
            csv_headers = ["Start (s)", "End (s)", "Segment"]
            vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
            raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]

            # CSV file generation. Use a context manager so the file handle is
            # closed and flushed (the previous bare open() leaked the handle),
            # and newline='' as the csv module requires.
            try:
                csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
                with open(csv_file_path, 'w', newline='', encoding='utf-8') as csv_file:
                    writer = csv.writer(csv_file)
                    writer.writerow(csv_headers)
                    writer.writerows(vis_data)
                print(f"CSV transcript saved to temporary file: {csv_file_path}")
                csv_button_update = gr.DownloadButton(value=csv_file_path, visible=True, label="Download Transcript (CSV)")
            except Exception as csv_e:
                gr.Error(f"Failed to create transcript CSV file: {csv_e}", duration=None)
                print(f"Error writing CSV: {csv_e}")

            # SRT file generation (skip when there are no segments).
            if segment_timestamps:
                try:
                    srt_content = generate_srt_content(segment_timestamps)
                    srt_file_path = Path(session_dir, f"transcription_{audio_name}.srt")
                    with open(srt_file_path, 'w', encoding='utf-8') as f:
                        f.write(srt_content)
                    print(f"SRT transcript saved to temporary file: {srt_file_path}")
                    srt_button_update = gr.DownloadButton(value=srt_file_path, visible=True, label="Download Transcript (SRT)")
                except Exception as srt_e:
                    gr.Warning(f"Failed to create transcript SRT file: {srt_e}", duration=5)
                    print(f"Error writing SRT: {srt_e}")

            gr.Info("Transcription complete.", duration=2)
            return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update

        except torch.cuda.OutOfMemoryError as e:
            error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
            print(f"CUDA OutOfMemoryError: {e}")
            gr.Error(error_msg, duration=None)
            return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update

        except FileNotFoundError:
            error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
            print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
            gr.Error(error_msg, duration=None)
            return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], audio_path, csv_button_update, srt_button_update

        except Exception as e:
            error_msg = f"Transcription failed: {e}"
            print(f"Error during transcription processing: {e}")
            gr.Error(error_msg, duration=None)
            vis_data = [["Error", "Error", error_msg]]
            raw_times_data = [[0.0, 0.0]]
            return vis_data, raw_times_data, audio_path, csv_button_update, srt_button_update
        finally:
            # --- Model Cleanup ---
            try:
                # Revert settings if they were applied for long audio
                if long_audio_settings_applied:
                    try:
                        print("Reverting long audio settings.")
                        model.change_attention_model("rel_pos")
                        model.change_subsampling_conv_chunking_factor(-1)
                        long_audio_settings_applied = False # Reset flag
                    except Exception as revert_e:
                        print(f"Warning: Failed to revert long audio settings: {revert_e}")
                        gr.Warning(f"Issue reverting model settings after long transcription: {revert_e}", duration=5)

                # Offload the model from the GPU. NOTE: `model` is a
                # module-level global, so the check must use globals();
                # the previous `'model' in locals()` was always False here
                # and the CPU offload never ran.
                if 'model' in globals() and hasattr(model, 'cpu'):
                    if device == 'cuda':
                        model.cpu()
                gc.collect()
                if device == 'cuda':
                    torch.cuda.empty_cache()
            except Exception as cleanup_e:
                print(f"Error during model cleanup: {cleanup_e}")
                gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
            # --- End Model Cleanup ---

    finally:
        # Always remove the temporary resampled/mono WAV, success or failure.
        if processed_audio_path and os.path.exists(processed_audio_path):
            try:
                os.remove(processed_audio_path)
                print(f"Temporary audio file {processed_audio_path} removed.")
            except Exception as e:
                print(f"Error removing temporary audio file {processed_audio_path}: {e}")

def play_segment(evt: gr.SelectData, raw_ts_list, current_audio_path):
    """Handle a results-table row selection: clip that segment and play it.

    Returns a gr.Audio update with the clipped segment (autoplaying), or an
    empty player when the selection or stored state is unusable.
    """
    def _empty_player():
        # Shared "nothing to play" update.
        return gr.Audio(value=None, label="Selected Segment")

    if not isinstance(raw_ts_list, list):
        print(f"Warning: raw_ts_list is not a list ({type(raw_ts_list)}). Cannot play segment.")
        return _empty_player()

    if not current_audio_path:
        print("No audio path available to play segment from.")
        return _empty_player()

    row = evt.index[0]

    if not (0 <= row < len(raw_ts_list)):
        print(f"Invalid index {row} selected for list of length {len(raw_ts_list)}.")
        return _empty_player()

    entry = raw_ts_list[row]
    if not isinstance(entry, (list, tuple)) or len(entry) != 2:
        print(f"Warning: Data at index {row} is not in the expected format [start, end].")
        return _empty_player()

    start_time_s, end_time_s = entry

    print(f"Attempting to play segment: {current_audio_path} from {start_time_s:.2f}s to {end_time_s:.2f}s")

    segment_data = get_audio_segment(current_audio_path, start_time_s, end_time_s)

    if not segment_data:
        print("Failed to get audio segment data.")
        return _empty_player()

    print("Segment data retrieved successfully.")
    return gr.Audio(value=segment_data, autoplay=True, label=f"Segment: {start_time_s:.2f}s - {end_time_s:.2f}s", interactive=False)

# Intro HTML rendered above the demo. Fixed malformed markup: the
# "long audio segments" <li> and the final <p> were left unclosed.
article = (
    "<p style='font-size: 1.1em;'>"
    "This demo showcases <code><a href='https://huggingface.co/Quantamhash/Quantum_STT_V2.0'>Quantum_STT_V2.0</a></code>, a 600-million-parameter model designed for high-quality English speech recognition."
    "</p>"
    "<p><strong style='color: red; font-size: 1.2em;'>Key Features:</strong></p>"
    "<ul style='font-size: 1.1em;'>"
    "    <li>Automatic punctuation and capitalization</li>"
    "    <li>Accurate word-level timestamps (click on a segment in the table below to play it!)</li>"
    "    <li>Efficiently transcribes long audio segments (<strong>updated to support upto 3 hours</strong>)</li>"
    "    <li>Robust performance on spoken numbers, and song lyrics transcription </li>"
    "</ul>"
    "<p style='font-size: 1.1em;'>"
    "This model is <strong>available for commercial and non-commercial use</strong>."
    "</p>"
)

# Sample clips surfaced via gr.Examples in the "Audio File" tab; each inner
# list maps positionally onto the Examples `inputs` (here just file_input).
examples = [
    ["data/example-yt_saTD1u8PorI.mp3"],
]

# Define an NVIDIA-inspired theme: a full light-to-dark green ramp centered
# on NVIDIA green (#76B900) for primary widgets, gray neutrals, Inter font.
nvidia_theme = gr_themes.Default(
    primary_hue=gr_themes.Color(
        c50="#E6F1D9", # Lightest green
        c100="#CEE3B3",
        c200="#B5D58C",
        c300="#9CC766",
        c400="#84B940",
        c500="#76B900", # NVIDIA Green
        c600="#68A600",
        c700="#5A9200",
        c800="#4C7E00",
        c900="#3E6A00", # Darkest green
        c950="#2F5600"
    ),
    neutral_hue="gray", # Use gray for neutral elements
    font=[gr_themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
).set()

# Build the Gradio UI with the custom theme. Component creation order inside
# the Blocks context determines on-page layout.
with gr.Blocks(theme=nvidia_theme) as demo:
    model_display_name = MODEL_NAME.split('/')[-1] if '/' in MODEL_NAME else MODEL_NAME
    gr.Markdown(f"<h1 style='text-align: center; margin: 0 auto;'>Speech Transcription with {model_display_name}</h1>")
    gr.HTML(article)

    # Per-session state: path of the last transcribed audio (for playback)
    # and the raw [start, end] pairs backing the results table rows.
    current_audio_path_state = gr.State(None)
    raw_timestamps_list_state = gr.State([])

    # Per-session scratch directory: created on page load, cleaned on unload.
    session_dir = gr.State()
    demo.load(start_session, outputs=[session_dir])

    with gr.Tabs():
        with gr.TabItem("Audio File"):
            file_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio File")
            gr.Examples(examples=examples, inputs=[file_input], label="Example Audio Files (Click to Load)")
            file_transcribe_btn = gr.Button("Transcribe Uploaded File", variant="primary")
        
        with gr.TabItem("Microphone"):
            mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
            mic_transcribe_btn = gr.Button("Transcribe Microphone Input", variant="primary")

    gr.Markdown("---")
    gr.Markdown("<p><strong style='color: #FF0000; font-size: 1.2em;'>Transcription Results (Click row to play segment)</strong></p>")

    # Define the DownloadButton *before* the DataFrame; both start hidden and
    # are made visible by get_transcripts_and_raw_times when files exist.
    with gr.Row():
        download_btn_csv = gr.DownloadButton(label="Download Transcript (CSV)", visible=False)
        download_btn_srt = gr.DownloadButton(label="Download Transcript (SRT)", visible=False)

    vis_timestamps_df = gr.DataFrame(
        headers=["Start (s)", "End (s)", "Segment"],
        datatype=["number", "number", "str"],
        wrap=True,
        label="Transcription Segments"
    )

    # selected_segment_player was defined after download_btn previously, keep it after df for layout
    selected_segment_player = gr.Audio(label="Selected Segment", interactive=False)

    # Both transcribe buttons share one handler; they differ only in the
    # audio source component fed as the first input.
    mic_transcribe_btn.click(
        fn=get_transcripts_and_raw_times,
        inputs=[mic_input, session_dir],
        outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
        api_name="transcribe_mic"
    )

    file_transcribe_btn.click(
        fn=get_transcripts_and_raw_times,
        inputs=[file_input, session_dir],
        outputs=[vis_timestamps_df, raw_timestamps_list_state, current_audio_path_state, download_btn_csv, download_btn_srt],
        api_name="transcribe_file"
    )

    # Clicking a table row clips and plays the corresponding audio segment.
    vis_timestamps_df.select(
        fn=play_segment,
        inputs=[raw_timestamps_list_state, current_audio_path_state],
        outputs=[selected_segment_player],
    )

    # Remove the session's /tmp scratch directory when the tab closes.
    demo.unload(end_session)

if __name__ == "__main__":
    print("Launching Gradio Demo...")
    # Enable request queueing so long-running transcriptions are serialized
    # instead of overlapping on the single shared model.
    demo.queue()
    demo.launch()