WizardForest committed
Commit 8854b79 · verified · 1 Parent(s): a76c272

Create app.py

Files changed (1)
  1. app.py +229 -0
app.py ADDED
@@ -0,0 +1,229 @@
import gradio as gr
import torchaudio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration, AutomaticSpeechRecognitionPipeline
import numpy as np
import tempfile
import os

# Global variables holding the loaded model
processor = None
model = None
asr_pipeline = None

def load_model():
    """Load the Breeze ASR 25 model."""
    global processor, model, asr_pipeline

    try:
        processor = WhisperProcessor.from_pretrained("MediaTek-Research/Breeze-ASR-25")
        model = WhisperForConditionalGeneration.from_pretrained("MediaTek-Research/Breeze-ASR-25")

        # Use CUDA if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device).eval()

        # Build the ASR pipeline
        asr_pipeline = AutomaticSpeechRecognitionPipeline(
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            chunk_length_s=0
        )

        return f"✅ 模型載入成功!使用設備: {device}"
    except Exception as e:
        return f"❌ 模型載入失敗: {str(e)}"

def preprocess_audio(audio_path):
    """Load an audio file, downmix to mono, and resample to 16 kHz."""
    # Load the audio
    waveform, sample_rate = torchaudio.load(audio_path)

    # Downmix to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0)

    waveform = waveform.squeeze().numpy()

    # Resample to 16 kHz
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(torch.tensor(waveform)).numpy()

    return waveform

def transcribe_audio(audio_input):
    """Main speech recognition function."""
    global asr_pipeline

    try:
        # Make sure the model is loaded
        if asr_pipeline is None:
            status = load_model()
            if "失敗" in status:
                return status, "", ""

        # Check the audio input
        if audio_input is None:
            return "❌ 請先上傳音訊檔案或進行錄音", "", ""

        # Handle the different audio input formats
        if isinstance(audio_input, str):
            # File path
            audio_path = audio_input
        elif isinstance(audio_input, tuple):
            # Gradio microphone format: (sample_rate, audio_data)
            sample_rate, audio_data = audio_input

            # Write the recording to a temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                # Make sure the sample data is float32
                if audio_data.dtype != np.float32:
                    audio_data = audio_data.astype(np.float32)

                # Normalize int16-range samples to [-1, 1]
                if audio_data.max() > 1.0:
                    audio_data = audio_data / 32768.0

                # Downmix multi-channel recordings to mono
                if audio_data.ndim > 1:
                    audio_data = audio_data.mean(axis=1)

                # Save as a wav file
                torchaudio.save(tmp_file.name, torch.tensor(audio_data).unsqueeze(0), sample_rate)
                audio_path = tmp_file.name
        else:
            return "❌ 不支援的音訊格式", "", ""

        # Preprocess the audio
        waveform = preprocess_audio(audio_path)

        # Run speech recognition
        result = asr_pipeline(waveform, return_timestamps=True)

        # Clean up the temporary file
        if isinstance(audio_input, tuple) and os.path.exists(audio_path):
            os.unlink(audio_path)

        # Format the results
        transcription = result["text"].strip()

        pure_text = ""
        srt_text = ""

        if "chunks" in result and result["chunks"]:
            srt_index = 0
            for chunk in result["chunks"]:
                start_time = chunk["timestamp"][0] if chunk["timestamp"][0] is not None else 0
                end_time = chunk["timestamp"][1] if chunk["timestamp"][1] is not None else 0
                text = chunk["text"].strip()

                if text:  # Skip empty segments
                    srt_index += 1

                    # Plain text (without timestamps)
                    pure_text += f"{text}\n"

                    # SRT format
                    start_srt = f"{int(start_time//3600):02d}:{int((start_time%3600)//60):02d}:{int(start_time%60):02d},{int((start_time%1)*1000):03d}"
                    end_srt = f"{int(end_time//3600):02d}:{int((end_time%3600)//60):02d}:{int(end_time%60):02d},{int((end_time%1)*1000):03d}"
                    srt_text += f"{srt_index}\n{start_srt} --> {end_srt}\n{text}\n\n"
        else:
            # No timestamps available: fall back to the plain transcription
            pure_text = transcription
            srt_text = f"1\n00:00:00,000 --> 00:00:10,000\n{transcription}\n\n"

        return "✅ 辨識完成", pure_text.strip(), srt_text.strip()

    except Exception as e:
        return f"❌ 辨識過程發生錯誤: {str(e)}", "", ""

def clear_all():
    """Clear all inputs and outputs."""
    return None, "🔄 已清除所有內容", "", ""

# Build the Gradio interface
with gr.Blocks(title="語音辨識系統", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
    # 🎤 語音辨識系統 - Breeze ASR 25

    ### 功能特色:
    - 🔧 使用 Breeze ASR 25 模型,專為繁體中文優化
    - ⏰ 顯示時間戳記
    - 🌐 強化中英混用辨識能力
    - 感謝[MediaTek-Research/Breeze-ASR-25](https://huggingface.co/MediaTek-Research/Breeze-ASR-25)
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Audio input area
            gr.Markdown("### 📂 音訊輸入(wav)")

            with gr.Tab("檔案上傳"):
                audio_file = gr.Audio(
                    label="上傳音訊檔案",
                    type="filepath",
                    format="wav"
                )

            with gr.Tab("即時錄音"):
                audio_mic = gr.Audio(
                    label="點擊開始錄音",
                    type="numpy",
                    format="wav"
                )

            # Control buttons
            with gr.Row():
                transcribe_btn = gr.Button("🚀 開始辨識", variant="primary", size="lg")
                clear_btn = gr.Button("🗑️ 清除", variant="secondary")

        with gr.Column(scale=1):
            # Status display
            status_output = gr.Textbox(
                label="📊 狀態",
                placeholder="等待操作...",
                interactive=False,
                lines=2
            )

            # Plain-text result
            pure_text_output = gr.Textbox(
                label="📄 純文字結果",
                placeholder="純文字結果...",
                lines=4,
                max_lines=10,
                show_copy_button=True
            )

            # SRT subtitle output
            srt_output = gr.Textbox(
                label="🎬 SRT 字幕格式",
                placeholder="SRT 格式字幕...",
                lines=6,
                max_lines=15,
                show_copy_button=True
            )

    # Event bindings
    def transcribe_wrapper(audio_file_val, audio_mic_val):
        # Prefer the uploaded file; fall back to the microphone recording
        audio_input = audio_file_val if audio_file_val else audio_mic_val
        return transcribe_audio(audio_input)

    transcribe_btn.click(
        fn=transcribe_wrapper,
        inputs=[audio_file, audio_mic],
        outputs=[status_output, pure_text_output, srt_output]
    )

    clear_btn.click(
        fn=clear_all,
        outputs=[audio_file, status_output, pure_text_output, srt_output]
    )

# Launch the application
if __name__ == "__main__":
    demo.launch()