Commit 1bf36cc
Parent(s): 401b3f7
init main framework

Files changed:
- README.md +135 -7
- aggregator/README.md +112 -0
- aggregator/__init__.py +10 -0
- aggregator/semantic_aggregator.py +333 -29
- dataset/transcripts/test1_segment_1_20250423_201934.json +4 -4
- display/display.py +56 -1
- main.py +300 -0
- optimizer/dispatcher.py +118 -37
- optimizer/llm_api_runner.py +96 -16
- optimizer/optimize_task.py +65 -14
- transcribe/transcribe.py +28 -15
- translator/translator.py +95 -18
- vad/__init__.py +34 -1
README.md
CHANGED
@@ -1,5 +1,5 @@
---
title: Pseudo-streaming audio transcription + LLM optimization system
emoji: 🎙️
colorFrom: indigo
colorTo: pink
@@ -8,17 +8,145 @@ app_file: app.py
pinned: false
---

# Pseudo-streaming audio transcription + LLM optimization system

This project implements a pseudo-streaming audio transcription system with VAD segmentation, Whisper transcription, semantic aggregation, LLM optimization, and translation. The design is modular: each component can run on its own or be combined with the others.

## System architecture

```mermaid
graph TD
    A[Audio stream input] --> B[VAD]
    B --> C[Transcribe]
    C --> D[Semantic aggregation controller]

    D --> E[Immediate output module]
    D --> F[LLM optimization dispatcher]

    F --> G[Optimized backfill module]
    G --> E
    E --> H[Translation module]
```

## Main modules

- **VAD segmenter**: detects speech segments from signals such as energy, silence, and speaker boundaries
- **Whisper transcription module**: runs Whisper on each VAD segment and outputs text plus timestamps
- **Semantic aggregation controller**: maintains a segment buffer, decides when the buffered segments form a complete semantic unit, and pushes it downstream
- **Immediate output module**: displays the aggregated transcription to the user right away
- **LLM optimization dispatcher**: receives sentences to be optimized and places them on the optimization task queue
- **Optimized backfill module**: matches LLM results to the original sentence ID and backfills (replaces) the displayed text
- **Translation module**: receives every sentence from the immediate output module and translates it into the target language
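The backfill path in the diagram (dispatcher → backfill module → immediate output) is driven by callbacks, mirroring how `main.py` wires these modules together. The snippet below is a minimal sketch of that wiring only: it assumes the dispatcher invokes its callback with `(sentence_id, original_text, optimized_text)` and that re-displaying the same sentence ID with `state="optimized"` performs the backfill.

```python
# Minimal sketch of the optimize-and-backfill loop.
from display.display import OutputRenderer
from optimizer.dispatcher import OptimizationDispatcher

renderer = OutputRenderer()

def on_optimized(sentence_id, original_text, optimized_text):
    # Backfill: re-display the same sentence_id with the optimized text.
    renderer.display(sentence_id, optimized_text, "optimized")

dispatcher = OptimizationDispatcher(max_workers=2, callback=on_optimized)

# Raw text is shown immediately, then queued for LLM optimization.
renderer.display("s1", "they need eat fast meeting now", "raw")
dispatcher.submit("s1", "they need eat fast meeting now")
dispatcher.wait_until_done()
```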
## Semantic aggregation controller

The semantic aggregation controller is the core module of the system. It aggregates the transcription results of multiple audio segments into complete semantic units (sentences) and pushes them to the downstream modules (display and translation).

### Main functions

1. **Maintain a buffer of transcription segments**: collect segments from the transcription module until they form a complete semantic unit
2. **Judge semantic completeness**: use ChatGPT with few-shot prompting to decide whether the buffered segments form a complete sentence
3. **Re-transcribe**: merge the audio of the buffered segments and transcribe the merged span again to improve accuracy
4. **Push downstream**: send the aggregated result to the display and translation modules

See [aggregator/README.md](aggregator/README.md) for details.

## Usage example

### Full pipeline example

```python
import soundfile as sf

from vad.vad import VoiceActivityDetector
from transcribe.transcribe import AudioTranscriber
from display.display import OutputRenderer
from translator.translator import NLLBTranslator
from aggregator.semantic_aggregator import SemanticAggregator

# Initialize the modules
vad = VoiceActivityDetector()
transcriber = AudioTranscriber(model="small", device="cuda")
renderer = OutputRenderer()
translator = NLLBTranslator()

# Callbacks
def display_callback(sentence_id, text, state):
    renderer.display(sentence_id, text, state)

def translate_callback(sentence_id, text):
    translation = translator.translate(text)
    print(f"[翻译] 句子 {sentence_id}: {translation}")

# Initialize the aggregator
aggregator = SemanticAggregator(
    on_display=display_callback,
    on_translate=translate_callback,
    transcriber=transcriber
)

# Process the audio
audio_data, sample_rate = sf.read("audio.wav")
segments = vad.detect_voice_segments(audio_data, sample_rate)

for i, (start, end) in enumerate(segments):
    segment_audio = audio_data[int(start * sample_rate):int(end * sample_rate)]
    results = transcriber.transcribe_segment(segment_audio, start_time=start)

    for result in results:
        result.segment_index = i + 1
        aggregator.add_segment(result)

# Force-flush the buffer at the end
aggregator.flush(force=True)
```

See [aggregator/integration_example.py](aggregator/integration_example.py) for a more detailed example.

## API service

The system also exposes a REST API so audio can be transcribed over HTTP.

### API endpoints

- `GET /` → health check
- `POST /transcribe` → upload a `.wav`/`.mp3` file and receive the transcribed text

### Usage example (cURL)

```bash
curl -X POST https://your-space-name.hf.space/transcribe \
  -F "[email protected]"
```

## Installation and running

### Requirements

- Python 3.8+
- PyTorch 1.12+
- CUDA 11.6+ (when using a GPU)

### Install dependencies

```bash
pip install -r requirements.txt
```

### Run the API service

```bash
python app.py
```

### Run the integration example

```bash
# Set the OpenAI API key (used for sentence-completeness judgment)
export OPENAI_API_KEY=your_api_key

# Run the integration example
python -m aggregator.integration_example
```

## License

[MIT License](LICENSE)
aggregator/README.md
ADDED
# Semantic aggregation controller (Semantic Aggregator)

The semantic aggregation controller is the core module of the pseudo-streaming audio transcription system. It aggregates the transcription results of multiple audio segments into complete semantic units (sentences) and pushes them to the downstream modules (display and translation).

## Main functions

1. **Maintain a buffer of transcription segments**: collect segments from the transcription module until they form a complete semantic unit
2. **Judge semantic completeness**: use ChatGPT with few-shot prompting to decide whether the buffered segments form a complete sentence
3. **Re-transcribe**: merge the audio of the buffered segments and transcribe the merged span again to improve accuracy
4. **Push downstream**: send the aggregated result to the display and translation modules

## Core components

### SentenceCompletionDetector

Uses ChatGPT with few-shot prompting to decide whether a text is a complete sentence.

```python
detector = SentenceCompletionDetector()
is_complete = detector.is_sentence_complete("你会学习到如何使用音频数据集")   # False
is_complete = detector.is_sentence_complete("你会学习到如何使用音频数据集。")  # True
```

### SemanticAggregator

The main aggregation controller, responsible for buffering, completeness judgment, re-transcription, and pushing downstream.

```python
aggregator = SemanticAggregator(
    on_display=display_callback,            # display callback
    on_translate=translate_callback,        # translation callback
    transcriber=transcriber,                # transcriber instance
    segments_dir="dataset/audio/segments",  # directory of audio segments
    max_window=5.0,           # maximum aggregation duration (seconds)
    max_segments=5,           # maximum number of aggregated segments
    min_gap=0.8,              # minimum gap that triggers aggregation (seconds)
    force_flush_timeout=3.0   # forced flush timeout (seconds)
)
```

## Aggregation decision logic

The aggregator uses the following rules to decide whether to aggregate and emit the buffer (a minimal sketch follows this list):

1. **Semantic completeness**: ask ChatGPT whether the text currently in the buffer forms a complete sentence
2. **Time gap**: if the gap between adjacent segments exceeds a threshold, treat them as separate semantic units
3. **Maximum window**: if the total aggregated duration exceeds a threshold, aggregate forcibly
4. **Maximum segment count**: if the number of buffered segments exceeds a threshold, aggregate forcibly
5. **Timeout**: if no new segment arrives for a long time, forcibly emit the current buffer
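As a rough illustration of rules 1–4, here is a minimal decision function over a buffer of segments. It only assumes each segment exposes `text`, `start_time`, and `end_time` (as the transcription results in this project do) and a detector with `is_sentence_complete`; the timeout rule is applied separately when new segments arrive, and this sketch is not the aggregator's exact implementation.

```python
def should_aggregate(buffer, detector, max_window=5.0, max_segments=5, min_gap=0.8):
    """Sketch of the aggregation rules: completeness, gap, window, and segment count."""
    if not buffer:
        return False
    # 1. Semantic completeness of the concatenated buffer text
    if detector.is_sentence_complete(",".join(seg.text for seg in buffer)):
        return True
    # 2. Gap between the two most recent segments
    if len(buffer) >= 2 and buffer[-1].start_time - buffer[-2].end_time > min_gap:
        return True
    # 3. Maximum aggregation window
    if buffer[-1].end_time - buffer[0].start_time > max_window:
        return True
    # 4. Maximum number of buffered segments
    return len(buffer) >= max_segments
```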
## Re-transcription flow

1. Fetch the audio data of every buffered segment
2. Concatenate the audio data
3. Re-transcribe the merged audio with the transcriber
4. Compare the re-transcription with the original aggregated text
5. If they differ, update the display and send the result to the translation module (see the sketch below)
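The following is a simplified sketch of steps 1–5, assuming numpy arrays for the audio and a transcriber with the `transcribe_segment(audio, start_time=...)` interface used elsewhere in this project; the real implementation lives in `aggregator/semantic_aggregator.py`.

```python
import numpy as np

def retranscribe_and_backfill(aggregated, audio_chunks, transcriber,
                              on_display, on_translate, sentence_id):
    """Merge buffered audio, re-transcribe it, and backfill/translate if the text changed."""
    # 1-2. Fetch and concatenate the buffered audio segments
    merged_audio = np.concatenate(audio_chunks)
    # 3. Re-transcribe the merged span as a whole
    results = transcriber.transcribe_segment(merged_audio, start_time=0.0)
    retranscribed = ",".join(r.text for r in results)
    # 4-5. Compare with the original aggregation and push downstream
    if retranscribed and retranscribed != aggregated:
        on_display(sentence_id, retranscribed, "optimized")
    on_translate(sentence_id, retranscribed or aggregated)
```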
## Usage example

```python
from display.display import OutputRenderer
from translator.translator import NLLBTranslator
from transcribe.transcribe import AudioTranscriber, TranscriptionResult
from aggregator.semantic_aggregator import SemanticAggregator

# Initialize the modules
renderer = OutputRenderer()
translator = NLLBTranslator()
transcriber = AudioTranscriber(model="small", device="cuda")

# Callbacks
def display_callback(sentence_id, text, state):
    renderer.display(sentence_id, text, state)

def translate_callback(sentence_id, text):
    translation = translator.translate(text)
    print(f"[翻译] 句子 {sentence_id}: {translation}")

# Initialize the aggregator
aggregator = SemanticAggregator(
    on_display=display_callback,
    on_translate=translate_callback,
    transcriber=transcriber
)

# Feed transcription results
for result in transcription_results:
    aggregator.add_segment(result)

# Force-flush the buffer at the end
aggregator.flush(force=True)
```

## Testing

The `test_aggregator.py` script can be used to test the aggregator:

```bash
# Set the OpenAI API key
export OPENAI_API_KEY=your_api_key

# Run the test script
python -m aggregator.test_aggregator
```

## Notes

1. The `OPENAI_API_KEY` environment variable must be set for ChatGPT-based sentence-completeness judgment
2. The audio segment directory must contain every audio file that may need re-transcription
3. The transcriber must be initialized correctly, including model, device, and compute type
4. The callbacks must handle the aggregated results correctly, including display and translation
aggregator/__init__.py
ADDED
"""
Semantic aggregation controller module.

This module aggregates the transcription results of multiple audio segments
into complete semantic units (sentences) and pushes them to the downstream
modules (display and translation).
"""

from .semantic_aggregator import SemanticAggregator, SentenceCompletionDetector

__all__ = ['SemanticAggregator', 'SentenceCompletionDetector']
aggregator/semantic_aggregator.py
CHANGED
from typing import List, Callable, Optional, Dict, Tuple
import uuid
import time
import os
import numpy as np
import soundfile as sf
import logging
from openai import OpenAI
from transcribe.transcribe import TranscriptionResult, AudioTranscriber

# Logger configuration
def setup_logger(name, level=logging.INFO):
    """Set up a logger."""
    logger = logging.getLogger(name)
    # Remove existing handlers to avoid duplicates
    if logger.handlers:
        logger.handlers.clear()

    # Add a new handler
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)
    # Do not propagate to the parent logger, to avoid duplicate logs
    logger.propagate = False
    return logger

# Create the logger
logger = setup_logger("aggregator")

class SentenceCompletionDetector:
    """
    Uses ChatGPT to decide whether a sentence is complete.
    """
    def __init__(self, model="gpt-3.5-turbo"):
        self.model = model
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def build_prompt(self, text: str) -> str:
        return (
            "判断以下语句是否为一句话的结尾,如果是,返回 True,否则返回 False:\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载\"\n"
            "False\n\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法\"\n"
            "True\n\n"
            "\"在开始学习之前,我们需要\"\n"
            "False\n\n"
            "\"在开始学习之前,我们需要了解一些基本概念\"\n"
            "True\n\n"
            "\"第一章,介绍基础知识\"\n"
            "True\n\n"
            f"\"{text}\"\n"
        )

    def is_sentence_complete(self, text: str) -> bool:
        """
        Decide whether the text is a complete sentence.
        """
        # # Simple rule: text ending with punctuation counts as a complete sentence
        # if text.strip() and text.strip()[-1] in "。!?!?.;;":
        #     return True

        # Use ChatGPT for a more nuanced judgment
        prompt = self.build_prompt(text)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "你是一个语言专家,擅长判断句子是否完整。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=10,
            )
            result = response.choices[0].message.content.strip()
            logger.debug(f"ChatGPT判断结果: {result}")
            return result.lower() == "true"
        except Exception as e:
            logger.error(f"调用ChatGPT出错: {str(e)}")
            # Fall back to a simple rule on error
            return len(text) > 20  # a longer text is likely a complete sentence

class SemanticAggregator:
    """
    Semantic aggregation controller
    - maintains the segment buffer
    - decides whether the buffer forms a complete semantic unit
    - pushes downstream (display/translator)
    """

    def __init__(
        self,
        on_display: Callable[[str, str, str], None],
        on_translate: Callable[[str, str], None],
        transcriber: AudioTranscriber,
        segments_dir: str = "dataset/audio/segments",
        max_window: float = 5.0,
        max_segments: int = 5,
        min_gap: float = 0.8,
        force_flush_timeout: float = 3.0
    ):
        """
        :param on_display: display callback (sentence_id, text, state)
        :param on_translate: translation callback (sentence_id, text)
        :param transcriber: transcriber instance
        :param segments_dir: directory of audio segments
        :param max_window: maximum aggregation duration (seconds)
        :param max_segments: maximum number of aggregated segments
        :param min_gap: minimum gap that triggers aggregation (seconds)
        :param force_flush_timeout: forced flush timeout (seconds)
        """
        self.buffer: List[TranscriptionResult] = []
        self.on_display = on_display
        self.on_translate = on_translate
        self.transcriber = transcriber
        self.segments_dir = segments_dir
        self.max_window = max_window
        self.max_segments = max_segments
        self.min_gap = min_gap
        self.force_flush_timeout = force_flush_timeout
        self.last_flush_time = time.time()
        self.sentence_detector = SentenceCompletionDetector()
        self.audio_cache: Dict[int, np.ndarray] = {}  # cache audio data to avoid re-reading files
        self.sample_rate = 16000  # assume a 16 kHz sample rate
        logger.debug(f"语义聚合器初始化完成,参数: max_window={max_window}, max_segments={max_segments}")

    def add_segment(self, result: TranscriptionResult):
        """
        Add a transcription segment to the buffer and decide automatically whether to aggregate.
        """
        self.buffer.append(result)
        logger.debug(f"添加片段: {result.text}")
        if self._should_aggregate():
            self._aggregate_and_flush()
        elif time.time() - self.last_flush_time > self.force_flush_timeout:
            logger.debug(f"超时强制刷新: {self.force_flush_timeout}秒")
            self.flush(force=True)

    def flush(self, force: bool = False):
        """
        Force the current aggregated content to be emitted.
        """
        if self.buffer:
            logger.debug(f"强制刷新缓冲区,当前片段数: {len(self.buffer)}")
            self._aggregate_and_flush()
        self.last_flush_time = time.time()

    def _should_aggregate(self) -> bool:
        """
        Decide whether the buffer should be aggregated and emitted.
        """
        if not self.buffer:
            return False

        # 1. Ask ChatGPT whether the buffered text is a complete sentence.
        #    Join segments with a comma, consistent with _aggregate_and_flush.
        segments = [seg.text for seg in self.buffer]
        combined_text = ",".join(segments)
        if self.sentence_detector.is_sentence_complete(combined_text):
            logger.info(f"检测到完整句子: {combined_text}")
            return True

        # 2. Gap between segments
        if len(self.buffer) >= 2:
            gap = self.buffer[-1].start_time - self.buffer[-2].end_time
            if gap > self.min_gap:
                logger.info(f"检测到较大间隔: {gap:.2f}秒")
                return True

        # 3. Maximum window / segment count
        total_duration = self.buffer[-1].end_time - self.buffer[0].start_time
        if total_duration > self.max_window:
            logger.info(f"达到最大时间窗口: {total_duration:.2f}秒")
            return True
        if len(self.buffer) >= self.max_segments:
            logger.info(f"达到最大片段数: {len(self.buffer)}")
            return True

        return False

    def _get_segment_audio(self, segment_index: int) -> np.ndarray:
        """
        Load the audio data of the segment with the given index.
        """
        if segment_index in self.audio_cache:
            return self.audio_cache[segment_index]

        # Read the audio file
        audio_path = os.path.join(self.segments_dir, f"test1_segment_{segment_index}.wav")
        try:
            audio_data, sample_rate = sf.read(audio_path)
            self.audio_cache[segment_index] = audio_data
            logger.debug(f"读取音频片段: {audio_path}, 长度: {len(audio_data)/sample_rate:.2f}秒")
            return audio_data
        except Exception as e:
            logger.error(f"读取音频文件失败: {audio_path}, 错误: {str(e)}")
            return np.array([])

    def _combine_audio_segments(self, segment_indices: List[int]) -> Tuple[np.ndarray, float]:
        """
        Merge multiple audio segments.
        Returns: (merged audio data, start time)
        """
        if not segment_indices:
            return np.array([]), 0.0

        # Collect the audio data of every segment
        audio_segments = []
        for idx in segment_indices:
            audio_data = self._get_segment_audio(idx)
            if len(audio_data) > 0:
                audio_segments.append(audio_data)

        if not audio_segments:
            return np.array([]), 0.0

        # Concatenate the audio data
        combined_audio = np.concatenate(audio_segments)

        # Use the start time of the first buffered segment
        first_segment = self.buffer[0]
        start_time = first_segment.start_time

        logger.debug(f"合并音频片段: {segment_indices}, 总长度: {len(combined_audio)/self.sample_rate:.2f}秒")
        return combined_audio, start_time

    def _retranscribe_segments(self, segment_indices: List[int]) -> List[TranscriptionResult]:
        """
        Re-transcribe the merged audio segments.
        """
        combined_audio, start_time = self._combine_audio_segments(segment_indices)
        if len(combined_audio) == 0:
            logger.warning("没有有效的音频数据可以重新转录")
            return []

        logger.debug(f"重新转录合并的音频片段, 长度: {len(combined_audio)/self.sample_rate:.2f}秒")
        try:
            results = self.transcriber.transcribe_segment(combined_audio, start_time=start_time)
            logger.debug(f"重新转录结果: {len(results)}条")
            return results
        except Exception as e:
            logger.error(f"重新转录失败: {str(e)}")
            return []

    def _aggregate_and_flush(self):
        """
        Aggregate the buffer and push it downstream.
        """
        if not self.buffer:
            return

        # Collect the indices of all buffered segments
        segment_indices = []
        for seg in self.buffer:
            if hasattr(seg, 'segment_index') and seg.segment_index is not None:
                if isinstance(seg.segment_index, list):
                    segment_indices.extend(seg.segment_index)
                else:
                    segment_indices.append(seg.segment_index)

        # Deduplicate and sort
        segment_indices = sorted(list(set(segment_indices)))

        # Generate a sentence ID
        sentence_id = str(uuid.uuid4())

        # 1. Emit the original text first, joining segments with commas
        original_segments = [seg.text for seg in self.buffer]
        # Join with commas; do not append a final full stop
        original_text = ",".join(original_segments)
        logger.info(f"原始聚合文本: {original_text}")
        self.on_display(sentence_id, original_text, "raw")

        # 2. Re-transcribe
        if segment_indices:
            retranscribed_results = self._retranscribe_segments(segment_indices)
            if retranscribed_results:
                # Join the re-transcribed segments with commas
                retranscribed_segments = [res.text for res in retranscribed_results]
                retranscribed_text = ",".join(retranscribed_segments)
                logger.info(f"重新转录文本: {retranscribed_text}")

                # Update the display if the re-transcription differs from the original text
                if retranscribed_text != original_text:
                    self.on_display(sentence_id, retranscribed_text, "optimized")

                # Send to the translation module
                self.on_translate(sentence_id, retranscribed_text)
            else:
                # Fall back to the original text when re-transcription fails
                logger.warning("重新转录失败,使用原始文本进行翻译")
                self.on_translate(sentence_id, original_text)
        else:
            # Fall back to the original text when no valid segment indices exist
            logger.warning("没有有效的片段索引,使用原始文本进行翻译")
            self.on_translate(sentence_id, original_text)

        # Clear the buffer
        buffer_size = len(self.buffer)
        self.buffer.clear()
        self.last_flush_time = time.time()
        logger.debug(f"清空缓冲区,释放 {buffer_size} 个片段")


def load_transcription_results(json_path):
    """Load transcription results from a JSON file."""
    import json
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    results = []
    for segment in data['segments']:
        result = TranscriptionResult(
            text=segment['text'],
            start_time=segment['start_time'],
            end_time=segment['end_time'],
            confidence=segment['confidence'],
            verified=segment['verified'],
            verified_text=segment['verified_text'],
            verification_notes=segment['verification_notes'],
            segment_index=segment['segment_index'] if 'segment_index' in segment else None
        )
        results.append(result)

    return results


if __name__ == "__main__":
    """Test the aggregator."""
    import os
    import sys
    import json
    from pathlib import Path

    # Configure the log level
    logger.setLevel(logging.DEBUG)

    # Check the OpenAI API key
    if not os.getenv("OPENAI_API_KEY"):
        logger.warning("未设置OPENAI_API_KEY环境变量,句子完整性判断将使用备用方法")

    # Initialize the modules
    from display.display import OutputRenderer
    from translator.translator import NLLBTranslator
    from transcribe.transcribe import AudioTranscriber

    # Initialize the renderer
    renderer = OutputRenderer()

    # Initialize the transcriber
    try:
        transcriber = AudioTranscriber(model="small", device="cuda", compute_type="int8")
        logger.info("使用GPU进行转录")
    except Exception as e:
        logger.warning(f"GPU初始化失败,使用CPU: {str(e)}")
        transcriber = AudioTranscriber(model="small", device="cpu", compute_type="float32")

    # Initialize the translator (optional)
    try:
        translator = NLLBTranslator()
        translation_enabled = True
    except Exception as e:
        logger.warning(f"翻译器初始化失败: {str(e)}")
        translation_enabled = False

    # Callbacks
    def display_callback(sentence_id, text, state):
        renderer.display(sentence_id, text, state)

    def translate_callback(sentence_id, text):
        if translation_enabled:
            try:
                translation = translator.translate(text)
                logger.info(f"[翻译] 句子 {sentence_id}: {translation}")
            except Exception as e:
                logger.error(f"翻译失败: {str(e)}")
        else:
            logger.info(f"[翻译已禁用] 句子 {sentence_id}: {text}")

    # Initialize the aggregator
    aggregator = SemanticAggregator(
        on_display=display_callback,
        on_translate=translate_callback,
        transcriber=transcriber,
        segments_dir="dataset/audio/segments",
        max_window=10.0,         # larger window for testing
        max_segments=10,         # more segments for testing
        force_flush_timeout=5.0  # longer timeout for testing
    )

    # Load test data
    test_file = "dataset/transcripts/test1_segment_1_20250423_201934.json"
    try:
        results = load_transcription_results(test_file)
        logger.info(f"加载了 {len(results)} 条转录结果")
    except Exception as e:
        logger.error(f"加载转录结果失败: {str(e)}")
        sys.exit(1)

    # Simulate adding transcription results
    for i, result in enumerate(results):
        logger.info(f"添加第 {i+1}/{len(results)} 条转录结果: {result.text}")
        aggregator.add_segment(result)
        # Simulate processing latency
        # time.sleep(0.5)

    # Force-flush the buffer
    aggregator.flush(force=True)

    logger.info("测试完成")
dataset/transcripts/test1_segment_1_20250423_201934.json
CHANGED
@@ -12,7 +12,7 @@
      "verification_notes": null
    },
    {
-     "text": "
+     "text": "音频数据处理",
      "start_time": 4.34,
      "end_time": 5.56,
      "confidence": 0.4482421875,
@@ -84,7 +84,7 @@
      "verification_notes": null
    },
    {
-     "text": "
+     "text": "包括波形,采样率和频谱图",
      "start_time": 26.28,
      "end_time": 28.28,
      "confidence": 0.732666015625,
@@ -111,7 +111,7 @@
      "verification_notes": null
    },
    {
-     "text": "
+     "text": "高效加载大规模音频数据集的流式加载方法。",
      "start_time": 33.54,
      "end_time": 36.5,
      "confidence": 0.88739013671875,
@@ -138,7 +138,7 @@
      "verification_notes": null
    },
    {
-     "text": "
+     "text": "基础的音频相关术语",
      "start_time": 40.86,
      "end_time": 42.4,
      "confidence": 0.609619140625,
display/display.py
CHANGED
"""
Display module - renders transcription results to the user.
"""

from rich.console import Console
from rich.text import Text
from typing import Literal
import logging

# Logger configuration
def setup_logger(name, level=logging.INFO):
    """Set up a logger."""
    logger = logging.getLogger(name)
    # Remove existing handlers to avoid duplicates
    if logger.handlers:
        logger.handlers.clear()

    # Add a new handler
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)
    # Do not propagate to the parent logger, to avoid duplicate logs
    logger.propagate = False
    return logger

# Create the logger
logger = setup_logger("display")

# Create the console object
console = Console()

class OutputRenderer:
    """
    Output renderer. Displays transcription results to the user,
    with different styles for raw and optimized text.
    """

    def __init__(self):
        """Initialize the output renderer."""
        self.history = {}  # used to update the optimized content of the same sentence
        logger.debug("输出渲染器初始化完成")

    def display(self, sentence_id: str, text: str, state: Literal["raw", "optimized"]):
        """
        Display a transcription result.

        :param sentence_id: sentence ID
        :param text: text content
        :param state: "raw" for the original text, "optimized" for the optimized text
        """
        if state == "raw":
            styled_text = Text(text, style="dim")  # dim grey marks raw output
            logger.debug(f"显示原始文本: {sentence_id}")
        elif state == "optimized":
            styled_text = Text(text, style="bold black")  # bold black
            logger.debug(f"显示优化文本: {sentence_id}")
        else:
            logger.error(f"未知的输出状态: {state}")
            raise ValueError(f"未知的输出状态: {state}")

        # Print the new content (or replace the previous entry)
        if sentence_id in self.history:
            console.print(f"[更新] 句子 {sentence_id}:", styled_text)
            logger.info(f"更新句子: {sentence_id}")
        else:
            console.print(f"[输出] 句子 {sentence_id}:", styled_text)
            logger.info(f"输出句子: {sentence_id}")

        # Record history
        self.history[sentence_id] = text
        logger.debug(f"句子内容: {text}")


if __name__ == "__main__":
    # Set log level to DEBUG for detailed output
    logger.setLevel(logging.DEBUG)

    # Test code
    renderer = OutputRenderer()

    # Display raw text
    renderer.display("s1", "I think we should start the meeting now.", "raw")

    # Simulate the optimized backfill
    renderer.display("s1", "I believe it's time to begin the meeting.", "optimized")
main.py
ADDED
"""
Pseudo-streaming audio transcription + LLM optimization system - main program

This program implements the full audio processing pipeline:
1. VAD segmentation
2. Whisper transcription
3. Semantic aggregation
4. Immediate output
5. LLM optimization
6. Translation

Usage:
    python main.py [--audio_path AUDIO_PATH] [--use_gpu] [--enable_translation] [--enable_optimization]
"""

import os
import sys
import time
import logging
import argparse
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Union
import uuid

# Logger configuration
def setup_logger(name, level=logging.INFO):
    """Set up a logger."""
    logger = logging.getLogger(name)
    # Remove existing handlers to avoid duplicates
    if logger.handlers:
        logger.handlers.clear()

    # Add a new handler
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)
    # Do not propagate to the parent logger, to avoid duplicate logs
    logger.propagate = False
    return logger

# Create the main logger
logger = setup_logger("main")

# Import the pipeline modules
from vad import VoiceActivityDetector
from transcribe.transcribe import AudioTranscriber, TranscriptionResult
from aggregator.semantic_aggregator import SemanticAggregator
from display.display import OutputRenderer
from optimizer.dispatcher import OptimizationDispatcher
from translator.translator import NLLBTranslator

class AudioProcessingPipeline:
    """Complete audio processing pipeline."""

    def __init__(
        self,
        audio_path: str,
        use_gpu: bool = True,
        enable_translation: bool = True,
        enable_optimization: bool = True,
        whisper_model: str = "large",
        log_level: Union[int, str] = logging.INFO
    ):
        """
        Initialize the processing pipeline.

        :param audio_path: path to the audio file
        :param use_gpu: whether to use the GPU
        :param enable_translation: whether to enable translation
        :param enable_optimization: whether to enable LLM optimization
        :param whisper_model: Whisper model size (tiny, base, small, medium, large)
        :param log_level: log level
        """
        # Set the log level
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        logger.setLevel(log_level)

        self.audio_path = audio_path
        self.use_gpu = use_gpu
        self.enable_translation = enable_translation
        self.enable_optimization = enable_optimization
        self.whisper_model = whisper_model

        # Check the device
        self.device = "cuda" if use_gpu and self._is_gpu_available() else "cpu"
        logger.info(f"使用设备: {self.device}")
        logger.debug(f"配置: whisper_model={whisper_model}, translation={enable_translation}, optimization={enable_optimization}")

        # Initialize the modules
        self._init_modules()

        # Load the audio data
        self.audio_data, self.sample_rate = sf.read(audio_path)
        logger.info(f"加载音频: {os.path.basename(audio_path)}, 长度: {len(self.audio_data)/self.sample_rate:.2f}秒")
        logger.debug(f"音频详情: 采样率={self.sample_rate}Hz, 形状={self.audio_data.shape}")

        # Map sentence IDs to optimization tasks
        self.optimization_tasks: Dict[str, str] = {}

    def _is_gpu_available(self) -> bool:
        """Check whether a GPU is available."""
        try:
            import torch
            if torch.cuda.is_available():
                logger.debug(f"检测到GPU: {torch.cuda.get_device_name(0)}")
                return True
            return False
        except ImportError:
            logger.debug("未检测到PyTorch,将使用CPU")
            return False

    def _init_modules(self):
        """Initialize the processing modules."""
        # 1. VAD
        logger.debug("初始化VAD模块...")
        self.vad = VoiceActivityDetector(save_audio=True, save_json=True)

        # 2. Transcriber
        logger.debug(f"初始化Whisper转录模块 (model={self.whisper_model}, device={self.device})...")
        self.transcriber = AudioTranscriber(
            model=self.whisper_model,
            device=self.device,
            compute_type="int8" if self.device == "cuda" else "float32"
        )

        # 3. Renderer
        logger.debug("初始化显示模块...")
        self.renderer = OutputRenderer()

        # 4. Optimization dispatcher (if enabled)
        if self.enable_optimization:
            logger.debug("初始化LLM优化调度器...")
            self.optimizer = OptimizationDispatcher(
                max_workers=2,
                callback=self._optimization_callback
            )
        else:
            logger.debug("LLM优化已禁用")
            self.optimizer = None

        # 5. Translator (if enabled)
        if self.enable_translation:
            logger.debug("初始化翻译模块...")
            try:
                self.translator = NLLBTranslator()
                self.translation_enabled = True
            except Exception as e:
                logger.warning(f"翻译器初始化失败: {str(e)}")
                self.translation_enabled = False
        else:
            logger.debug("翻译已禁用")
            self.translation_enabled = False
            self.translator = None

        # 6. Aggregator
        logger.debug("初始化语义聚合控制器...")
        self.aggregator = SemanticAggregator(
            on_display=self._display_callback,
            on_translate=self._translate_callback,
            transcriber=self.transcriber,
            segments_dir="dataset/audio/segments",
            max_window=5.0,
            max_segments=5,
            min_gap=0.8,
            force_flush_timeout=3.0
        )

        logger.info("所有模块初始化完成")

    def _display_callback(self, sentence_id: str, text: str, state: str):
        """Display callback."""
        self.renderer.display(sentence_id, text, state)

        # If optimization is enabled and this is raw text, submit an optimization task
        if self.enable_optimization and state == "raw" and self.optimizer:
            logger.debug(f"提交优化任务: {sentence_id}")
            self.optimizer.submit(sentence_id, text)

    def _translate_callback(self, sentence_id: str, text: str):
        """Translation callback."""
        if self.translation_enabled and self.translator:
            try:
                # The translator logs the source text and result itself; just call translate
                self.translator.translate(text)
                logger.debug(f"已翻译句子: {sentence_id}")
            except Exception as e:
                logger.error(f"翻译失败: {str(e)}")

    def _optimization_callback(self, sentence_id: str, original_text: str, optimized_text: str):
        """Optimization callback."""
        logger.debug(f"收到优化结果: {sentence_id}")
        # Update the display
        self.renderer.display(sentence_id, optimized_text, "optimized")
        # If translation is enabled, translate the optimized text
        if self.translation_enabled:
            logger.debug(f"翻译优化后的文本: {sentence_id}")
            self._translate_callback(sentence_id, optimized_text)

    def process(self):
        """Process the audio file."""
        logger.info("开始处理音频...")

        # 1. VAD segmentation
        logger.debug("执行VAD分段...")
        segments = self.vad.detect_voice_segments(self.audio_data, self.sample_rate)
        logger.info(f"VAD分段完成: {len(segments)}个片段")

        # 2. Transcribe each segment
        for i, (start, end) in enumerate(segments):
            logger.debug(f"转录片段 {i+1}/{len(segments)}: {start:.2f}s -> {end:.2f}s")

            # Extract the segment's audio data
            segment_audio = self.audio_data[int(start * self.sample_rate):int(end * self.sample_rate)]

            # Transcribe the segment
            results = self.transcriber.transcribe_segment(segment_audio, start_time=start)

            # Attach the segment index
            for result in results:
                result.segment_index = i + 1  # segment indices start at 1

            # Feed the aggregator
            for result in results:
                logger.debug(f"添加转录结果: {result.text}")
                self.aggregator.add_segment(result)

            # Simulate processing latency
            time.sleep(0.1)

        # 3. Force-flush the buffer at the end
        logger.debug("强制刷新缓冲区...")
        self.aggregator.flush(force=True)

        # 4. Wait for all optimization tasks to finish
        if self.enable_optimization and self.optimizer:
            logger.debug("等待所有优化任务完成...")
            self.optimizer.wait_until_done()

        logger.info("音频处理完成")

def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="伪流式音频转写 + LLM优化系统")
    parser.add_argument("--audio_path", type=str, default="dataset/audio/test1.wav",
                        help="音频文件路径")
    parser.add_argument("--use_gpu", action="store_true", default=True,
                        help="是否使用GPU")
    parser.add_argument("--enable_translation", action="store_true", default=True,
                        help="是否启用翻译")
    parser.add_argument("--enable_optimization", action="store_true", default=True,
                        help="是否启用LLM优化")
    parser.add_argument("--whisper_model", type=str, default="small",
                        choices=["tiny", "base", "small", "medium", "large"],
                        help="Whisper模型大小")
    parser.add_argument("--log_level", type=str, default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                        help="日志级别")
    return parser.parse_args()

def main():
    """Entry point."""
    # Parse command-line arguments
    args = parse_args()

    # Resolve the log level
    log_level = getattr(logging, args.log_level)

    # Apply the log level to every module logger
    for module in ["main", "vad", "transcribe", "aggregator", "display", "optimizer", "translator"]:
        setup_logger(module, log_level)

    # Check the OpenAI API key (used for sentence-completeness judgment and optimization)
    if not os.getenv("OPENAI_API_KEY") and args.enable_optimization:
        logger.warning("未设置OPENAI_API_KEY环境变量,句子完整性判断将使用备用方法")

    # Check that the audio file exists
    if not os.path.exists(args.audio_path):
        logger.error(f"音频文件不存在: {args.audio_path}")
        return

    # Create and run the processing pipeline
    pipeline = AudioProcessingPipeline(
        audio_path=args.audio_path,
        use_gpu=args.use_gpu,
        enable_translation=args.enable_translation,
        enable_optimization=args.enable_optimization,
        whisper_model=args.whisper_model,
        log_level=log_level
    )

    # Process the audio
    pipeline.process()

if __name__ == "__main__":
    main()
optimizer/dispatcher.py
CHANGED
"""
Optimization dispatcher - manages the LLM optimization task queue.
"""

import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional
from optimizer.llm_api_runner import ChatGPTRunner
from optimizer.optimize_task import OptimizeTask

# Logger configuration
def setup_logger(name, level=logging.INFO):
    """Set up a logger."""
    logger = logging.getLogger(name)
    # Remove existing handlers to avoid duplicates
    if logger.handlers:
        logger.handlers.clear()

    # Add a new handler
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)
    # Do not propagate to the parent logger, to avoid duplicate logs
    logger.propagate = False
    return logger

# Create the logger
logger = setup_logger("optimizer")

class OptimizationDispatcher:
    """
    Optimization dispatcher that manages the LLM optimization task queue.
    Supports processing multiple optimization tasks asynchronously.
    """

    def __init__(self, max_workers: int = 2, callback: Optional[Callable] = None):
        """
        Initialize the optimization dispatcher.

        :param max_workers: maximum number of worker threads
        :param callback: callback invoked when an optimization finishes
        """
        self.tasks = {}  # map sentence IDs to tasks
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.model_runner = ChatGPTRunner()
        self.callback = callback
        logger.debug(f"优化调度器初始化完成,最大工作线程数: {max_workers}")

    def submit(self, sentence_id: str, text: str, callback: Optional[Callable] = None):
        """
        Submit an optimization task.

        :param sentence_id: sentence ID
        :param text: text to optimize
        :param callback: callback for this task; the default callback is used when None
        """
        task_callback = callback or self.callback
        task = OptimizeTask(sentence_id, text, task_callback)
        self.tasks[sentence_id] = task
        logger.debug(f"提交优化任务: {sentence_id}")

        # Run the task in the thread pool
        self.executor.submit(self._process_task, task)
        logger.debug(f"任务已提交到线程池: {sentence_id}")

    def _process_task(self, task: OptimizeTask):
        """
        Process an optimization task.

        :param task: the optimization task
        """
        try:
            logger.debug(f"开始处理任务: {task.sentence_id}")
            # Optimize the text with the model runner
            optimized_text = self.model_runner.optimize(task.text)
            logger.debug(f"任务处理完成: {task.sentence_id}")

            # Invoke the callback
            if task.callback:
                task.callback(task.sentence_id, task.text, optimized_text)
                logger.debug(f"已调用回调函数: {task.sentence_id}")

            # Remove the task from the task map
            if task.sentence_id in self.tasks:
                del self.tasks[task.sentence_id]

            logger.info(f"优化任务完成: {task.sentence_id}")
        except Exception as e:
            logger.error(f"处理任务出错: {task.sentence_id}, 错误: {str(e)}")

    def wait_until_done(self, timeout: Optional[float] = None):
        """
        Wait for all tasks to finish.

        :param timeout: kept for API compatibility; ThreadPoolExecutor.shutdown() does not accept a timeout
        :return: True once all tasks have finished
        """
        logger.debug(f"等待所有任务完成,当前任务数: {len(self.tasks)}")
        self.executor.shutdown(wait=True)
        # Create a fresh thread pool so the dispatcher can be reused
        self.executor = ThreadPoolExecutor(max_workers=self.executor._max_workers)
        logger.debug("所有任务已完成")
        return True


if __name__ == "__main__":
    # Set log level to DEBUG for detailed output
    logger.setLevel(logging.DEBUG)

    # Test callback
    def test_callback(sentence_id, original_text, optimized_text):
        logger.info(f"[回填] {sentence_id}: {optimized_text}")

    # Create the dispatcher
    dispatcher = OptimizationDispatcher(callback=test_callback)

    # Submit test tasks
    dispatcher.submit("s001", "we maybe start tomorrow okay")
    dispatcher.submit("s002", "they need eat fast meeting now")

    # Wait for the tasks to finish
    dispatcher.wait_until_done()

    logger.info("测试完成")
optimizer/llm_api_runner.py
CHANGED
@@ -1,15 +1,61 @@
"""
ChatGPT optimizer - refines transcription results through the OpenAI API.
"""

from openai import OpenAI
import os
import logging
import time

# Logger configuration
def setup_logger(name, level=logging.INFO):
    """Set up a logger."""
    logger = logging.getLogger(name)
    # Remove existing handlers to avoid duplicates
    if logger.handlers:
        logger.handlers.clear()

    # Add a new handler
    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)
    # Do not propagate to the parent logger, to avoid duplicate logs
    logger.propagate = False
    return logger

# Create the logger
logger = setup_logger("optimizer.api")

# Default model
MODEL_NAME = "gpt-3.5-turbo"

class ChatGPTRunner:
    """
    ChatGPT optimizer that refines transcription results via the OpenAI API.
    """

    def __init__(self, model: str = MODEL_NAME):
        """
        Initialize the ChatGPT optimizer.

        :param model: name of the model to use
        """
        self.model = model
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            logger.warning("未设置OPENAI_API_KEY环境变量")
        self.client = OpenAI(api_key=api_key)
        logger.debug(f"ChatGPT优化器初始化完成,使用模型: {model}")

    def build_prompt(self, text: str) -> str:
        """
        Build the optimization prompt.

        :param text: text to optimize
        :return: the assembled prompt
        """
        return (
            "示例:\n"
            "原句:你门现在就得开会了,别迟到了。\n"
@@ -20,26 +66,60 @@
            "修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
            "原句:系统将进入留言模式,请耐行等待。\n"
            "修改:系统将进入留言模式,请耐心等待。\n"
            "原句:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。\n"
            "修改:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法。\n"
            f"原句:{text}\n"
            f"修改:"
        )

    def optimize(self, text: str, max_tokens: int = 256) -> str:
        """
        Optimize a piece of text.

        :param text: text to optimize
        :param max_tokens: maximum number of generated tokens
        :return: the optimized text
        """
        logger.debug(f"开始优化文本: {text}")
        start_time = time.time()

        # Build the prompt
        prompt = self.build_prompt(text)

        try:
            # Call the API
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "你是用于优化语音识别的转写结果的校对助手。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.4,
                max_tokens=max_tokens,
            )

            # Extract the result
            result = response.choices[0].message.content.strip()

            # Log the elapsed time
            elapsed_time = time.time() - start_time
            logger.debug(f"优化完成,耗时: {elapsed_time:.2f}秒")
            logger.info(f"优化结果: {result}")

            return result
        except Exception as e:
            logger.error(f"优化失败: {str(e)}")
            # Return the original text on error
            return text


if __name__ == "__main__":
    # Set log level to DEBUG for detailed output
    logger.setLevel(logging.DEBUG)

    # Test the optimizer
    runner = ChatGPTRunner(MODEL_NAME)
    test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"

    logger.info(f"优化前: {test_input}")
    result = runner.optimize(test_input)
    logger.info(f"优化后: {result}")
optimizer/optimize_task.py
CHANGED
@@ -1,23 +1,74 @@
+"""
+优化任务 - 表示一个LLM优化任务
+"""
+
+import logging
+from typing import Callable, Optional
+
+# 配置日志
+def setup_logger(name, level=logging.INFO):
+    """设置日志记录器"""
+    logger = logging.getLogger(name)
+    # 清除所有已有的handler,避免重复
+    if logger.handlers:
+        logger.handlers.clear()
+
+    # 添加新的handler
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.setLevel(level)
+    # 禁止传播到父logger,避免重复日志
+    logger.propagate = False
+    return logger
+
+# 创建日志记录器
+logger = setup_logger("optimizer.task")
 
 class OptimizeTask:
+    """
+    优化任务,表示一个需要LLM优化的文本任务
+    """
+
+    def __init__(self, sentence_id: str, text: str, callback: Optional[Callable[[str, str, str], None]] = None):
+        """
+        初始化优化任务
+
+        :param sentence_id: 句子ID
+        :param text: 需要优化的文本
+        :param callback: 优化完成后的回调函数,接收参数(sentence_id, original_text, optimized_text)
+        """
         self.sentence_id = sentence_id
         self.text = text
         self.callback = callback
+        logger.debug(f"创建优化任务: {sentence_id}")
+
+    def __str__(self):
+        """字符串表示"""
+        return f"OptimizeTask(id={self.sentence_id}, text={self.text[:20]}...)"
 
 
 if __name__ == "__main__":
+    # 设置日志级别为DEBUG以查看详细信息
+    logger.setLevel(logging.DEBUG)
+
+    # 测试回调函数
+    def fake_callback(sid, original_text, optimized_text):
+        logger.info(f"[回调] 优化结果:({sid}) -> {optimized_text}")
+
+    # 创建任务
     task = OptimizeTask("s001", "they go home maybe tomorrow", fake_callback)
+
+    # 创建模型运行器
+    from optimizer.llm_api_runner import ChatGPTRunner
+    runner = ChatGPTRunner()
+
+    # 优化文本
+    optimized_text = runner.optimize(task.text)
+
+    # 调用回调
+    if task.callback:
+        task.callback(task.sentence_id, task.text, optimized_text)
+
+    logger.info("测试完成")
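Because `OptimizeTask` is only a data holder plus callback, a queue-based worker is the natural consumer. The sketch below is illustrative only; the project's actual scheduling lives in `optimizer/dispatcher.py`, which is not shown in this excerpt.

```python
# Illustrative worker loop (assumption: not the project's dispatcher) that drains
# OptimizeTask objects from a queue and fires each task's callback with
# (sentence_id, original_text, optimized_text).
import queue
import threading

from optimizer.optimize_task import OptimizeTask
from optimizer.llm_api_runner import ChatGPTRunner

def worker(task_queue: queue.Queue, runner: ChatGPTRunner) -> None:
    while True:
        task = task_queue.get()
        if task is None:  # sentinel: stop the worker
            task_queue.task_done()
            break
        optimized = runner.optimize(task.text)
        if task.callback:
            task.callback(task.sentence_id, task.text, optimized)
        task_queue.task_done()

if __name__ == "__main__":
    q = queue.Queue()
    threading.Thread(target=worker, args=(q, ChatGPTRunner()), daemon=True).start()
    q.put(OptimizeTask("s001", "they go home maybe tomorrow",
                       lambda sid, orig, opt: print(f"{sid}: {orig} -> {opt}")))
    q.put(None)   # ask the worker to exit after the task is processed
    q.join()      # block until both queue entries are marked done
```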
transcribe/transcribe.py
CHANGED
@@ -21,11 +21,25 @@ class TranscriptionResult:
     segment_index: Optional[int] = None  # 添加片段索引字段
 
 # 配置日志
+def setup_logger(name, level=logging.INFO):
+    """设置日志记录器"""
+    logger = logging.getLogger(name)
+    # 清除所有已有的handler,避免重复
+    if logger.handlers:
+        logger.handlers.clear()
+
+    # 添加新的handler
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.setLevel(level)
+    # 禁止传播到父logger,避免重复日志
+    logger.propagate = False
+    return logger
+
+# 创建日志记录器
+logger = setup_logger("transcribe")
 
 class AudioTranscriber:
     def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
@@ -44,12 +58,12 @@ class AudioTranscriber:
         log_level = getattr(logging, log_level.upper())
         logger.setLevel(log_level)
 
+        logger.debug(f"初始化转录器: model={model}, device={device}, compute_type={compute_type}")
 
         from faster_whisper import WhisperModel
         self.model = WhisperModel(model, device=device, compute_type=compute_type)
 
+        logger.debug("Whisper模型加载完成")
 
     def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List[TranscriptionResult]:
         """
@@ -64,7 +78,6 @@ class AudioTranscriber:
         """
         start_process_time = time.time()
 
         logger.debug(f"开始转录音频片段,长度: {len(audio_data)} 采样点 ({len(audio_data)/16000:.2f}秒)")
 
         try:
@@ -74,9 +87,9 @@ class AudioTranscriber:
 
         segments = list(segments_generator)
 
+        logger.debug(f"转录成功,片段数: {len(segments)}")
         if len(segments) > 0:
+            logger.debug(f"第一个片段: {segments[0]}")
 
         results = []
         for seg in segments:
@@ -141,6 +154,7 @@ class AudioTranscriber:
         with open(output_path, 'w', encoding='utf-8') as f:
             json.dump(data, f, ensure_ascii=False, indent=2)
 
+        logger.info(f"转录结果已保存到: {output_path}")
         return output_path
 
 
@@ -149,12 +163,11 @@ if __name__ == "__main__":
     audio_path = "dataset/audio/test1.wav"  # 替换为实际的音频文件路径
     import soundfile as sf
 
+    # 设置日志级别为DEBUG以查看详细信息
+    logger.setLevel(logging.DEBUG)
 
+    # 初始化转录器
+    processor = AudioTranscriber(log_level="DEBUG")
 
     try:
         audio_data, sample_rate = sf.read(audio_path)
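A minimal driving sketch for the transcriber, assuming a 16 kHz mono input (the duration log above divides by 16000); the file path, model size, and CPU fallback are placeholders rather than project defaults.

```python
# Sketch: transcribe a single file with AudioTranscriber (assumes 16 kHz mono audio).
import soundfile as sf
from transcribe.transcribe import AudioTranscriber

audio, sample_rate = sf.read("dataset/audio/test1.wav", dtype="float32")
if audio.ndim > 1:
    audio = audio.mean(axis=1)  # downmix to mono

transcriber = AudioTranscriber(model="small", device="cpu", compute_type="int8",
                               log_level="INFO")
for result in transcriber.transcribe_segment(audio, start_time=0.0):
    print(result)
```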
translator/translator.py
CHANGED
@@ -1,20 +1,56 @@
+"""
+翻译模块 - 使用NLLB模型进行多语言翻译
+"""
 
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 from langdetect import detect
 import torch
 import time
+import logging
+
+# 配置日志
+def setup_logger(name, level=logging.INFO):
+    """设置日志记录器"""
+    logger = logging.getLogger(name)
+    # 清除所有已有的handler,避免重复
+    if logger.handlers:
+        logger.handlers.clear()
+
+    # 添加新的handler
+    handler = logging.StreamHandler()
+    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.setLevel(level)
+    # 禁止传播到父logger,避免重复日志
+    logger.propagate = False
+    return logger
+
+# 创建日志记录器
+logger = setup_logger("translator")
 
 class NLLBTranslator:
+    """
+    NLLB翻译器,使用Facebook的NLLB模型进行多语言翻译
+    """
+
     def __init__(self, model_name="facebook/nllb-200-distilled-600M", default_target="eng_Latn"):
+        """
+        初始化NLLB翻译器
+
+        :param model_name: 模型名称
+        :param default_target: 默认目标语言代码
+        """
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        logger.debug(f"使用设备: {self.device}")
+
         if self.device.type == "cuda":
+            logger.debug(f"GPU设备: {torch.cuda.get_device_name(0)}")
             total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
+            logger.debug(f"GPU显存: {total_mem:.1f} GB")
 
+        # 加载模型和分词器
+        logger.debug(f"加载模型: {model_name}")
         self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(
             model_name,
@@ -22,51 +58,92 @@ class NLLBTranslator:
         ).to(self.device)
 
         self.default_target = default_target
+        logger.debug(f"翻译器初始化完成,默认目标语言: {default_target}")
 
     def detect_lang_code(self, text: str) -> str:
+        """
+        检测文本语言并返回NLLB语言代码
+
+        :param text: 要检测的文本
+        :return: NLLB语言代码
+        """
         try:
             lang = detect(text)
+            logger.debug(f"检测到语言: {lang}")
         except Exception:
+            logger.debug("语言检测失败,默认使用中文(zh)")
             lang = "zh-cn"
 
+        # 语言代码映射
         lang_map = {
             "zh-cn": "zho_Hans", "zh": "zho_Hans", "en": "eng_Latn", "fr": "fra_Latn",
             "de": "deu_Latn", "ja": "jpn_Jpan", "ko": "kor_Hang", "ar": "arb_Arab"
         }
+
         lang_code = lang_map.get(lang.lower(), "eng_Latn")
+        logger.debug(f"映射语言代码: {lang} -> {lang_code}")
         return lang_code
 
     def translate(self, text: str, target_lang_code: str = None) -> str:
+        """
+        翻译文本到目标语言
+
+        :param text: 要翻译的文本
+        :param target_lang_code: 目标语言代码,如果为None则使用默认目标语言
+        :return: 翻译后的文本
+        """
+        logger.debug("开始翻译")
+
+        # 记录原文(INFO级别)
+        logger.info(f"[翻译原文] {text}")
+
+        # 检测源语言
         src_lang = self.detect_lang_code(text)
         tgt_lang = target_lang_code or self.default_target
 
+        # 准备输入
         self.tokenizer.src_lang = src_lang
         inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
         inputs["forced_bos_token_id"] = self.tokenizer.convert_tokens_to_ids(tgt_lang)
 
+        # 执行翻译
         start = time.time()
         with torch.no_grad():
             output = self.model.generate(**inputs, max_new_tokens=80)
+
+        # 解码结果
         result = self.tokenizer.decode(output[0], skip_special_tokens=True)
+
+        # 记录耗时和结果
+        duration = time.time() - start
+        logger.debug(f"翻译完成: {src_lang} -> {tgt_lang}, 耗时: {duration:.2f}秒")
+
+        # 记录翻译结果(INFO级别)
+        logger.info(f"[翻译结果] {result}")
+
         return result
 
 
 if __name__ == "__main__":
+    # 设置日志级别为DEBUG以查看详细信息
+    logger.setLevel(logging.DEBUG)
+
+    # 创建翻译器
     translator = NLLBTranslator()
 
+    # 测试中文到英文
     zh_text = "你会学习到如何使用音频数据集"
+    logger.info("\n==== 中文 → 英文 ====")
+    result = translator.translate(zh_text, target_lang_code="eng_Latn")
+    logger.info(f"测试完成: {result}")
 
+    # 测试英文到法语
     en_text = "This audio processing pipeline is fast and accurate."
+    logger.info("\n==== 英文 → 法语 ====")
+    result = translator.translate(en_text, target_lang_code="fra_Latn")
+    logger.info(f"测试完成: {result}")
+
+    # 测试英文到阿拉伯语
+    logger.info("\n==== 英文 → 阿拉伯语 ====")
+    result = translator.translate(en_text, target_lang_code="arb_Arab")
+    logger.info(f"测试完成: {result}")
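As a usage sketch, the translator slots straight into a per-sentence callback; the callback name and sentence id below are illustrative, not part of this commit.

```python
# Sketch: translate each aggregated sentence as it is pushed downstream.
from translator.translator import NLLBTranslator

translator = NLLBTranslator(default_target="eng_Latn")

def translate_callback(sentence_id: str, text: str) -> None:
    # Source language is auto-detected via detect_lang_code(); target defaults to English.
    translated = translator.translate(text)
    print(f"[{sentence_id}] {translated}")

translate_callback("s001", "你会学习到如何使用音频数据集")
```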
vad/__init__.py
CHANGED
@@ -1,3 +1,36 @@
 from .vad import AudioVad, AudioSegment
+import numpy as np
+from typing import List, Tuple
 
+class VoiceActivityDetector:
+    """
+    VAD检测器,用于检测音频中的语音片段
+    这是一个包装类,内部使用AudioVad实现功能
+    """
+
+    def __init__(self, save_audio=True, save_json=True):
+        """
+        初始化VAD检测器
+
+        :param save_audio: 是否保存分段音频
+        :param save_json: 是否保存JSON元数据
+        """
+        self.vad = AudioVad(
+            save_audio=save_audio,
+            save_json=save_json,
+            output_dir="dataset/audio/segments",
+            json_dir="dataset/audio/metadata"
+        )
+
+    def detect_voice_segments(self, audio_data: np.ndarray, sample_rate: int) -> List[Tuple[float, float]]:
+        """
+        检测音频中的语音片段
+
+        :param audio_data: 音频数据
+        :param sample_rate: 采样率
+        :return: 语音片段列表,每个片段为(开始时间, 结束时间)的元组
+        """
+        segments = self.vad.process_audio_data(audio_data, sample_rate)
+        return [(segment.start_time, segment.end_time) for segment in segments]
+
+__all__ = ['AudioVad', 'AudioSegment', 'VoiceActivityDetector']
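Finally, a short sketch of the wrapper in use; the audio path is a placeholder and the disk side-effects of AudioVad are switched off here.

```python
# Sketch: list detected speech windows for one file using the wrapper above.
import soundfile as sf
from vad import VoiceActivityDetector

audio, sample_rate = sf.read("dataset/audio/test1.wav", dtype="float32")
detector = VoiceActivityDetector(save_audio=False, save_json=False)

for start, end in detector.detect_voice_segments(audio, sample_rate):
    print(f"speech segment: {start:.2f}s -> {end:.2f}s")
```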
|