tianyaogavin committed
Commit 1bf36cc
1 Parent(s): 401b3f7

init main framework
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Whisper API Server
+ title: Pseudo-Streaming Audio Transcription + LLM Optimization System
  emoji: 🎙️
  colorFrom: indigo
  colorTo: pink
@@ -8,17 +8,145 @@ app_file: app.py
  pinned: false
  ---

- # Whisper API Server with faster-whisper
-
- This Space provides a REST API to transcribe audio using faster-whisper + FastAPI.
-
- ## API Endpoints
-
- - `GET /` → Health check
- - `POST /transcribe` → Upload a `.wav/.mp3` file and receive transcript text
-
- ## Example usage (cURL)
+ # Pseudo-Streaming Audio Transcription + LLM Optimization System
+
+ This project implements a pseudo-streaming audio transcription system covering VAD segmentation, Whisper transcription, semantic aggregation, LLM optimization, and translation. The design is modular: each component can run on its own or in combination with the others.
+
+ ## System Architecture
+
+ ```mermaid
+ graph TD
+ A[Audio stream input] --> B[VAD]
+ B --> C[Transcribe]
+ C --> D[Semantic aggregation controller]
+ D --> E[Immediate output module]
+ D --> F[LLM optimization dispatcher]
+ F --> G[Optimized backfill module]
+ G --> E
+ E --> H[Translation module]
+ ```
+
+ ## Main Modules
+
+ - **VAD segmenter**: detects speech segments from energy, silence, and utterance-boundary cues
+ - **Whisper transcription module**: transcribes each VAD segment with Whisper, emitting text plus timestamps
+ - **Semantic aggregation controller**: maintains a segment buffer pool, decides when buffered segments form a complete semantic unit, and pushes it downstream
+ - **Immediate output module**: displays aggregated transcripts to the user right away
+ - **LLM optimization dispatcher**: receives sentences to be optimized and enqueues optimization tasks
+ - **Optimized backfill module**: matches LLM-optimized results to the original sentence IDs and replaces the text in place
+ - **Translation module**: translates every sentence coming from the immediate output module into the target language
+
+ ## Semantic Aggregation Controller
+
+ The semantic aggregation controller is the core module of the system. It merges the transcription results of multiple audio segments into complete semantic units (sentences) and pushes them to the downstream display and translation modules.
+
+ ### Key Features
+
+ 1. **Maintains a transcription segment buffer pool**: collects segments from the transcription module until they form a complete semantic unit
+ 2. **Judges semantic completeness**: uses ChatGPT with few-shot prompting to decide whether a group of segments forms a complete sentence
+ 3. **Re-transcription**: concatenates the segments' audio and re-transcribes it as a whole to improve accuracy
+ 4. **Pushes downstream**: sends the aggregated result to the display and translation modules
+
+ See [aggregator/README.md](aggregator/README.md) for details.
+
+ ## Usage Examples
+
+ ### End-to-End Example
+
+ ```python
+ import soundfile as sf
+
+ from vad.vad import VoiceActivityDetector
+ from transcribe.transcribe import AudioTranscriber
+ from display.display import OutputRenderer
+ from translator.translator import NLLBTranslator
+ from aggregator.semantic_aggregator import SemanticAggregator
+
+ # Initialize the modules
+ vad = VoiceActivityDetector()
+ transcriber = AudioTranscriber(model="small", device="cuda")
+ renderer = OutputRenderer()
+ translator = NLLBTranslator()
+
+ # Callbacks
+ def display_callback(sentence_id, text, state):
+     renderer.display(sentence_id, text, state)
+
+ def translate_callback(sentence_id, text):
+     translation = translator.translate(text)
+     print(f"[Translation] sentence {sentence_id}: {translation}")
+
+ # Initialize the aggregator
+ aggregator = SemanticAggregator(
+     on_display=display_callback,
+     on_translate=translate_callback,
+     transcriber=transcriber
+ )
+
+ # Process the audio
+ audio_data, sample_rate = sf.read("audio.wav")
+ segments = vad.detect_voice_segments(audio_data, sample_rate)
+
+ for i, (start, end) in enumerate(segments):
+     segment_audio = audio_data[int(start * sample_rate):int(end * sample_rate)]
+     results = transcriber.transcribe_segment(segment_audio, start_time=start)
+
+     for result in results:
+         result.segment_index = i + 1
+         aggregator.add_segment(result)
+
+ # Finally, force-flush the buffer
+ aggregator.flush(force=True)
+ ```
+
+ For a more detailed example, see [aggregator/integration_example.py](aggregator/integration_example.py).
+
+ ## API Service
+
+ The system also exposes a REST API for transcribing audio over HTTP.
+
+ ### API Endpoints
+
+ - `GET /` → Health check
+ - `POST /transcribe` → Upload a `.wav/.mp3` file and receive the transcript text
+
+ ### Usage Example (cURL)

  ```bash
  curl -X POST https://your-space-name.hf.space/transcribe \
+ ```
+
+ ## Installation and Running
+
+ ### Requirements
+
+ - Python 3.8+
+ - PyTorch 1.12+
+ - CUDA 11.6+ (when using a GPU)
+
+ ### Install Dependencies
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Run the API Service
+
+ ```bash
+ python app.py
+ ```
+
+ ### Run the Integration Example
+
+ ```bash
+ # Set the OpenAI API key (used for sentence-completeness judgment)
+ export OPENAI_API_KEY=your_api_key
+
+ # Run the integration example
+ python -m aggregator.integration_example
+ ```
+
+ ## License
+
+ [MIT License](LICENSE)
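The end-to-end example in the README slices the waveform by converting VAD boundaries from seconds to sample indices. A minimal standalone sketch of that conversion (the sample rate and segment times below are invented for illustration, not taken from this repo's data):

```python
import numpy as np

def slice_segment(audio: np.ndarray, start: float, end: float, sample_rate: int) -> np.ndarray:
    """Convert second-based VAD boundaries into a sample-index slice."""
    return audio[int(start * sample_rate):int(end * sample_rate)]

sr = 16000
audio = np.zeros(sr * 10)           # ten seconds of silence as a stand-in signal
seg = slice_segment(audio, 1.5, 3.0, sr)
print(len(seg) / sr)                # 1.5 (seconds of audio in the slice)
```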
aggregator/README.md ADDED
@@ -0,0 +1,112 @@
+ # Semantic Aggregator
+
+ The semantic aggregation controller is the core module of the pseudo-streaming audio transcription system. It merges the transcription results of multiple audio segments into complete semantic units (sentences) and pushes them to the downstream display and translation modules.
+
+ ## Key Features
+
+ 1. **Maintains a transcription segment buffer pool**: collects segments from the transcription module until they form a complete semantic unit
+ 2. **Judges semantic completeness**: uses ChatGPT with few-shot prompting to decide whether a group of segments forms a complete sentence
+ 3. **Re-transcription**: concatenates the segments' audio and re-transcribes it as a whole to improve accuracy
+ 4. **Pushes downstream**: sends the aggregated result to the display and translation modules
+
+ ## Core Components
+
+ ### SentenceCompletionDetector
+
+ Uses ChatGPT with few-shot prompting to decide whether a piece of text is a complete sentence.
+
+ ```python
+ detector = SentenceCompletionDetector()
+ is_complete = detector.is_sentence_complete("你会学习到如何使用音频数据集")   # False
+ is_complete = detector.is_sentence_complete("你会学习到如何使用音频数据集。")  # True
+ ```
+
+ ### SemanticAggregator
+
+ The main aggregation controller, responsible for buffering, completeness judgment, re-transcription, and pushing results downstream.
+
+ ```python
+ aggregator = SemanticAggregator(
+     on_display=display_callback,            # display callback
+     on_translate=translate_callback,        # translation callback
+     transcriber=transcriber,                # transcriber instance
+     segments_dir="dataset/audio/segments",  # audio segment directory
+     max_window=5.0,                         # maximum aggregation duration (seconds)
+     max_segments=5,                         # maximum number of aggregated segments
+     min_gap=0.8,                            # minimum gap that triggers aggregation (seconds)
+     force_flush_timeout=3.0                 # forced-flush timeout (seconds)
+ )
+ ```
+
+ ## Aggregation Decision Logic
+
+ The aggregator flushes its buffer when any of the following holds:
+
+ 1. **Semantic completeness**: ChatGPT judges the buffered text to form a complete sentence
+ 2. **Time gap**: the gap between adjacent segments exceeds a threshold, indicating separate semantic units
+ 3. **Maximum window**: the total aggregated duration exceeds a threshold, forcing a flush
+ 4. **Maximum segment count**: the number of buffered segments exceeds a threshold, forcing a flush
+ 5. **Timeout**: no new segment arrives for a long time, forcing output of the current buffer
+
+ ## Re-transcription Flow
+
+ 1. Fetch the audio data of all buffered segments
+ 2. Concatenate the audio
+ 3. Re-transcribe the concatenated audio with the transcriber
+ 4. Compare the re-transcription against the original aggregated text
+ 5. If they differ, update the display and send the result to the translation module
+
+ ## Usage Example
+
+ ```python
+ from display.display import OutputRenderer
+ from translator.translator import NLLBTranslator
+ from transcribe.transcribe import AudioTranscriber, TranscriptionResult
+ from aggregator.semantic_aggregator import SemanticAggregator
+
+ # Initialize the modules
+ renderer = OutputRenderer()
+ translator = NLLBTranslator()
+ transcriber = AudioTranscriber(model="small", device="cuda")
+
+ # Callbacks
+ def display_callback(sentence_id, text, state):
+     renderer.display(sentence_id, text, state)
+
+ def translate_callback(sentence_id, text):
+     translation = translator.translate(text)
+     print(f"[Translation] sentence {sentence_id}: {translation}")
+
+ # Initialize the aggregator
+ aggregator = SemanticAggregator(
+     on_display=display_callback,
+     on_translate=translate_callback,
+     transcriber=transcriber
+ )
+
+ # Feed transcription results
+ for result in transcription_results:
+     aggregator.add_segment(result)
+
+ # Finally, force-flush the buffer
+ aggregator.flush(force=True)
+ ```
+
+ ## Testing
+
+ The `test_aggregator.py` script exercises the aggregator:
+
+ ```bash
+ # Set the OpenAI API key
+ export OPENAI_API_KEY=your_api_key
+
+ # Run the test script
+ python -m aggregator.test_aggregator
+ ```
+
+ ## Notes
+
+ 1. The `OPENAI_API_KEY` environment variable must be set for ChatGPT-based sentence-completeness judgment
+ 2. The audio segment directory must contain every audio file that may be re-transcribed
+ 3. The transcriber must be initialized correctly, including model, device, and compute type
+ 4. The callbacks must handle aggregated results properly, including display and translation
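The time-based triggers in the decision logic above (rules 2 through 4) can be sketched without the LLM call. The `Seg` dataclass and `should_flush` name below are simplified stand-ins for the real `TranscriptionResult` and `_should_aggregate`, not the module's actual code:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Seg:  # simplified stand-in for TranscriptionResult
    start_time: float
    end_time: float

def should_flush(buffer: List[Seg], min_gap=0.8, max_window=5.0, max_segments=5) -> bool:
    if not buffer:
        return False
    # Rule 2: a large gap between adjacent segments ends the semantic unit
    if len(buffer) >= 2 and buffer[-1].start_time - buffer[-2].end_time > min_gap:
        return True
    # Rules 3-4: total duration or segment count exceeds the cap
    total = buffer[-1].end_time - buffer[0].start_time
    return total > max_window or len(buffer) >= max_segments

print(should_flush([Seg(0.0, 1.0), Seg(2.5, 3.0)]))  # True: gap of 1.5s > 0.8s
```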
aggregator/__init__.py ADDED
@@ -0,0 +1,10 @@
+ """
+ Semantic aggregation controller module.
+
+ Merges the transcription results of multiple audio segments into complete
+ semantic units (sentences) and pushes them to the downstream display and
+ translation modules.
+ """
+
+ from .semantic_aggregator import SemanticAggregator, SentenceCompletionDetector
+
+ __all__ = ['SemanticAggregator', 'SentenceCompletionDetector']
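The exported `SentenceCompletionDetector` parses the ChatGPT reply by comparing it to the literal string `true`, case-insensitively (`result.lower() == "true"` in the module below). That parsing rule, isolated into a hypothetical helper name for illustration:

```python
def parse_completion_reply(reply: str) -> bool:
    """Treat exactly 'true' (case-insensitive, whitespace ignored) as complete."""
    return reply.strip().lower() == "true"

print(parse_completion_reply(" True "))  # True
print(parse_completion_reply("False"))   # False
print(parse_completion_reply("unsure"))  # False: any other reply means incomplete
```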
aggregator/semantic_aggregator.py CHANGED
@@ -1,47 +1,140 @@
- from typing import List, Callable, Optional
- from vad.audio_transcriber import TranscriptionResult
+ from typing import List, Callable, Optional, Dict, Tuple
  import uuid
  import time
+ import os
+ import numpy as np
+ import soundfile as sf
+ import logging
+ from openai import OpenAI
+ from transcribe.transcribe import TranscriptionResult, AudioTranscriber
+
+ # Logging configuration
+ def setup_logger(name, level=logging.INFO):
+     """Set up a logger."""
+     logger = logging.getLogger(name)
+     # Clear existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to parent loggers, to avoid duplicate log lines
+     logger.propagate = False
+     return logger
+
+ # Create the logger
+ logger = setup_logger("aggregator")
+
+ class SentenceCompletionDetector:
+     """
+     Uses ChatGPT to judge whether a sentence is complete.
+     """
+     def __init__(self, model="gpt-3.5-turbo"):
+         self.model = model
+         self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+     def build_prompt(self, text: str) -> str:
+         return (
+             "判断以下语句是否为一句话的结尾,如果是,返回 True,否则返回 False:\n"
+             "\"你会学习到如何使用音频数据集,包括音频数据加载\"\n"
+             "False\n\n"
+             "\"你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法\"\n"
+             "True\n\n"
+             "\"在开始学习之前,我们需要\"\n"
+             "False\n\n"
+             "\"在开始学习之前,我们需要了解一些基本概念\"\n"
+             "True\n\n"
+             "\"第一章,介绍基础知识\"\n"
+             "True\n\n"
+             f"\"{text}\"\n"
+         )
+
+     def is_sentence_complete(self, text: str) -> bool:
+         """
+         Judge whether the text is a complete sentence.
+         """
+         # Ask ChatGPT for the judgment
+         prompt = self.build_prompt(text)
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[
+                     {"role": "system", "content": "你是一个语言专家,擅长判断句子是否完整。"},
+                     {"role": "user", "content": prompt}
+                 ],
+                 temperature=0.1,
+                 max_tokens=10,
+             )
+             result = response.choices[0].message.content.strip()
+             logger.debug(f"ChatGPT judgment: {result}")
+             return result.lower() == "true"
+         except Exception as e:
+             logger.error(f"ChatGPT call failed: {str(e)}")
+             # Fall back to a simple heuristic on error
+             return len(text) > 20  # longer text is more likely a complete sentence
 
  class SemanticAggregator:
      """
      Semantic aggregation controller
      - maintains the segment buffer pool
      - decides whether buffered segments form a complete semantic unit
-     - pushes results downstream (display/optimizer)
+     - pushes results downstream (display/translator)
      """
 
      def __init__(
          self,
-         on_aggregate: Callable[[str, List[TranscriptionResult], str], None],
+         on_display: Callable[[str, str, str], None],
+         on_translate: Callable[[str, str], None],
+         transcriber: AudioTranscriber,
+         segments_dir: str = "dataset/audio/segments",
          max_window: float = 5.0,
          max_segments: int = 5,
          min_gap: float = 0.8,
          force_flush_timeout: float = 3.0
      ):
          """
-         :param on_aggregate: aggregation callback (text, segments, sentence_id)
+         :param on_display: display callback (sentence_id, text, state)
+         :param on_translate: translation callback (sentence_id, text)
+         :param transcriber: transcriber instance
+         :param segments_dir: audio segment directory
          :param max_window: maximum aggregation duration (seconds)
          :param max_segments: maximum number of aggregated segments
          :param min_gap: minimum gap that triggers aggregation (seconds)
          :param force_flush_timeout: forced-flush timeout (seconds)
          """
          self.buffer: List[TranscriptionResult] = []
-         self.on_aggregate = on_aggregate
+         self.on_display = on_display
+         self.on_translate = on_translate
+         self.transcriber = transcriber
+         self.segments_dir = segments_dir
          self.max_window = max_window
          self.max_segments = max_segments
          self.min_gap = min_gap
          self.force_flush_timeout = force_flush_timeout
          self.last_flush_time = time.time()
+         self.sentence_detector = SentenceCompletionDetector()
+         self.audio_cache: Dict[int, np.ndarray] = {}  # cache audio data to avoid re-reading
+         self.sample_rate = 16000  # assume a 16 kHz sample rate
+         logger.debug(f"Semantic aggregator initialized: max_window={max_window}, max_segments={max_segments}")
 
      def add_segment(self, result: TranscriptionResult):
          """
          Add a transcription segment to the buffer and decide automatically whether to aggregate.
          """
          self.buffer.append(result)
+         logger.debug(f"Added segment: {result.text}")
          if self._should_aggregate():
              self._aggregate_and_flush()
          elif time.time() - self.last_flush_time > self.force_flush_timeout:
+             logger.debug(f"Forced flush after timeout: {self.force_flush_timeout}s")
              self.flush(force=True)
 
      def flush(self, force: bool = False):
@@ -49,6 +142,7 @@ class SemanticAggregator:
          Force output of the currently aggregated content.
          """
          if self.buffer:
+             logger.debug(f"Force-flushing buffer, segment count: {len(self.buffer)}")
              self._aggregate_and_flush()
          self.last_flush_time = time.time()
 
@@ -58,49 +152,259 @@
          """
          if not self.buffer:
              return False
-         # 1. Ends with sentence-final punctuation
-         if self.buffer[-1].text.strip() and self.buffer[-1].text.strip()[-1] in "。!?!?":
+
+         # 1. Ask ChatGPT whether the buffered text forms a complete sentence.
+         #    Join segments with commas, consistent with _aggregate_and_flush.
+         segments = [seg.text for seg in self.buffer]
+         combined_text = ",".join(segments)
+         if self.sentence_detector.is_sentence_complete(combined_text):
+             logger.info(f"Complete sentence detected: {combined_text}")
              return True
+
          # 2. Gap between segments
          if len(self.buffer) >= 2:
              gap = self.buffer[-1].start_time - self.buffer[-2].end_time
              if gap > self.min_gap:
+                 logger.info(f"Large gap detected: {gap:.2f}s")
                  return True
+
          # 3. Maximum window / segment count
          total_duration = self.buffer[-1].end_time - self.buffer[0].start_time
-         if total_duration > self.max_window or len(self.buffer) >= self.max_segments:
+         if total_duration > self.max_window:
+             logger.info(f"Maximum time window reached: {total_duration:.2f}s")
+             return True
+         if len(self.buffer) >= self.max_segments:
+             logger.info(f"Maximum segment count reached: {len(self.buffer)}")
              return True
+
          return False
 
+     def _get_segment_audio(self, segment_index: int) -> np.ndarray:
+         """
+         Fetch the audio data for the segment with the given index.
+         """
+         if segment_index in self.audio_cache:
+             return self.audio_cache[segment_index]
+
+         # Read the audio file
+         audio_path = os.path.join(self.segments_dir, f"test1_segment_{segment_index}.wav")
+         try:
+             audio_data, sample_rate = sf.read(audio_path)
+             self.audio_cache[segment_index] = audio_data
+             logger.debug(f"Read audio segment: {audio_path}, length: {len(audio_data)/sample_rate:.2f}s")
+             return audio_data
+         except Exception as e:
+             logger.error(f"Failed to read audio file: {audio_path}, error: {str(e)}")
+             return np.array([])
+
+     def _combine_audio_segments(self, segment_indices: List[int]) -> Tuple[np.ndarray, float]:
+         """
+         Concatenate several audio segments.
+         Returns: (combined audio data, start time)
+         """
+         if not segment_indices:
+             return np.array([]), 0.0
+
+         # Fetch the audio data for every segment
+         audio_segments = []
+         for idx in segment_indices:
+             audio_data = self._get_segment_audio(idx)
+             if len(audio_data) > 0:
+                 audio_segments.append(audio_data)
+
+         if not audio_segments:
+             return np.array([]), 0.0
+
+         # Concatenate the audio data
+         combined_audio = np.concatenate(audio_segments)
+
+         # Use the first buffered segment's start time
+         first_segment = self.buffer[0]
+         start_time = first_segment.start_time
+
+         logger.debug(f"Combined audio segments: {segment_indices}, total length: {len(combined_audio)/self.sample_rate:.2f}s")
+         return combined_audio, start_time
+
+     def _retranscribe_segments(self, segment_indices: List[int]) -> List[TranscriptionResult]:
+         """
+         Re-transcribe the combined audio segments.
+         """
+         combined_audio, start_time = self._combine_audio_segments(segment_indices)
+         if len(combined_audio) == 0:
+             logger.warning("No valid audio data to re-transcribe")
+             return []
+
+         logger.debug(f"Re-transcribing combined audio, length: {len(combined_audio)/self.sample_rate:.2f}s")
+         try:
+             results = self.transcriber.transcribe_segment(combined_audio, start_time=start_time)
+             logger.debug(f"Re-transcription produced {len(results)} results")
+             return results
+         except Exception as e:
+             logger.error(f"Re-transcription failed: {str(e)}")
+             return []
+
      def _aggregate_and_flush(self):
          """
          Aggregate and push downstream.
          """
-         text = "".join([seg.text for seg in self.buffer])
+         if not self.buffer:
+             return
+
+         # Collect the indices of all buffered segments
+         segment_indices = []
+         for seg in self.buffer:
+             if hasattr(seg, 'segment_index') and seg.segment_index is not None:
+                 if isinstance(seg.segment_index, list):
+                     segment_indices.extend(seg.segment_index)
+                 else:
+                     segment_indices.append(seg.segment_index)
+
+         # De-duplicate and sort
+         segment_indices = sorted(set(segment_indices))
+
+         # Generate a sentence ID
          sentence_id = str(uuid.uuid4())
-         self.on_aggregate(text, self.buffer, sentence_id)
+
+         # 1. Emit the original text first, joining segments with commas
+         original_segments = [seg.text for seg in self.buffer]
+         # Join with commas, without appending a final period
+         original_text = ",".join(original_segments)
+         logger.info(f"Original aggregated text: {original_text}")
+         self.on_display(sentence_id, original_text, "raw")
+
+         # 2. Re-transcribe
+         if segment_indices:
+             retranscribed_results = self._retranscribe_segments(segment_indices)
+             if retranscribed_results:
+                 # Join the re-transcribed segments with commas
+                 retranscribed_segments = [res.text for res in retranscribed_results]
+                 retranscribed_text = ",".join(retranscribed_segments)
+                 logger.info(f"Re-transcribed text: {retranscribed_text}")
+
+                 # If the re-transcription differs from the original, update the display
+                 if retranscribed_text != original_text:
+                     self.on_display(sentence_id, retranscribed_text, "optimized")
+
+                 # Send to the translation module
+                 self.on_translate(sentence_id, retranscribed_text)
+             else:
+                 # Fall back to the original text if re-transcription failed
+                 logger.warning("Re-transcription failed; translating the original text")
+                 self.on_translate(sentence_id, original_text)
+         else:
+             # Fall back to the original text if there are no valid segment indices
+             logger.warning("No valid segment indices; translating the original text")
+             self.on_translate(sentence_id, original_text)
+
+         # Clear the buffer
+         buffer_size = len(self.buffer)
          self.buffer.clear()
          self.last_flush_time = time.time()
+         logger.debug(f"Cleared buffer, released {buffer_size} segments")
 
 
+ def load_transcription_results(json_path):
+     """Load transcription results from a JSON file."""
+     import json
+     with open(json_path, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     results = []
+     for segment in data['segments']:
+         result = TranscriptionResult(
+             text=segment['text'],
+             start_time=segment['start_time'],
+             end_time=segment['end_time'],
+             confidence=segment['confidence'],
+             verified=segment['verified'],
+             verified_text=segment['verified_text'],
+             verification_notes=segment['verification_notes'],
+             segment_index=segment.get('segment_index')
+         )
+         results.append(result)
+
+     return results
+
  if __name__ == "__main__":
-     # Example: integrating display and optimizer (no helper functions; the main flow is written inline)
+     """Exercise the aggregator."""
+     import os
+     import sys
+     import json
+     from pathlib import Path
+
+     # Log level
+     logger.setLevel(logging.DEBUG)
+
+     # Check for the OpenAI API key
+     if not os.getenv("OPENAI_API_KEY"):
+         logger.warning("OPENAI_API_KEY is not set; sentence-completeness judgment will use the fallback heuristic")
+
+     # Initialize the modules
      from display.display import OutputRenderer
-     from optimizer.dispatcher import OptimizationDispatcher
-
+     from translator.translator import NLLBTranslator
+     from transcribe.transcribe import AudioTranscriber
+
+     # Renderer
      renderer = OutputRenderer()
-     dispatcher = OptimizationDispatcher(max_workers=2)
-
-     def aggregate_callback(text, segments, sentence_id):
-         # Called inline from the main flow
-         print(f"[Aggregated] sentence_id={sentence_id}")
-         print(f"Aggregated text: {text}")
-         renderer.display(sentence_id, text, state="raw")
-         dispatcher.submit(sentence_id, text, callback=None)  # callback can be customized
-
-     aggregator = SemanticAggregator(on_aggregate=aggregate_callback)
-
-     # Assuming `results` holds TranscriptionResult objects:
-     # for result in results:
-     #     aggregator.add_segment(result)
-     # aggregator.flush(force=True)
+
+     # Transcriber
+     try:
+         transcriber = AudioTranscriber(model="small", device="cuda", compute_type="int8")
+         logger.info("Transcribing on GPU")
+     except Exception as e:
+         logger.warning(f"GPU initialization failed, falling back to CPU: {str(e)}")
+         transcriber = AudioTranscriber(model="small", device="cpu", compute_type="float32")
+
+     # Translator (optional)
+     try:
+         translator = NLLBTranslator()
+         translation_enabled = True
+     except Exception as e:
+         logger.warning(f"Translator initialization failed: {str(e)}")
+         translation_enabled = False
+
+     # Callbacks
+     def display_callback(sentence_id, text, state):
+         renderer.display(sentence_id, text, state)
+
+     def translate_callback(sentence_id, text):
+         if translation_enabled:
+             try:
+                 translation = translator.translate(text)
+                 logger.info(f"[Translation] sentence {sentence_id}: {translation}")
+             except Exception as e:
+                 logger.error(f"Translation failed: {str(e)}")
+         else:
+             logger.info(f"[Translation disabled] sentence {sentence_id}: {text}")
+
+     # Aggregator
+     aggregator = SemanticAggregator(
+         on_display=display_callback,
+         on_translate=translate_callback,
+         transcriber=transcriber,
+         segments_dir="dataset/audio/segments",
+         max_window=10.0,         # enlarged window for testing
+         max_segments=10,         # enlarged segment count for testing
+         force_flush_timeout=5.0  # enlarged timeout for testing
+     )
+
+     # Load test data
+     test_file = "dataset/transcripts/test1_segment_1_20250423_201934.json"
+     try:
+         results = load_transcription_results(test_file)
+         logger.info(f"Loaded {len(results)} transcription results")
+     except Exception as e:
+         logger.error(f"Failed to load transcription results: {str(e)}")
+         sys.exit(1)
+
+     # Simulate feeding in transcription results
+     for i, result in enumerate(results):
+         logger.info(f"Adding transcription result {i+1}/{len(results)}: {result.text}")
+         aggregator.add_segment(result)
+         # Simulated processing delay
+         # time.sleep(0.5)
+
+     # Force-flush the buffer
+     aggregator.flush(force=True)
+
+     logger.info("Test complete")
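`_aggregate_and_flush` joins the buffered segment texts with Chinese commas and flattens, de-duplicates, and sorts the segment indices before re-transcription. That bookkeeping can be sketched in isolation (the helper name and sample inputs below are invented for illustration):

```python
def collect_indices(raw_indices):
    """Flatten, de-duplicate, and sort segment indices, as the aggregator does."""
    flat = []
    for idx in raw_indices:
        if idx is None:
            continue  # segments without an index are skipped
        flat.extend(idx if isinstance(idx, list) else [idx])
    return sorted(set(flat))

# Comma-joined "raw" text shown to the user before re-transcription
segments = ["你会学习到如何使用音频数据集", "包括音频数据加载"]
combined = ",".join(segments)

indices = collect_indices([3, [1, 2], None, 2])
print(indices)  # [1, 2, 3]
```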
dataset/transcripts/test1_segment_1_20250423_201934.json CHANGED
@@ -12,7 +12,7 @@
      "verification_notes": null
    },
    {
-     "text": "音频数据出来",
+     "text": "音频数据处理",
      "start_time": 4.34,
      "end_time": 5.56,
      "confidence": 0.4482421875,
@@ -84,7 +84,7 @@
      "verification_notes": null
    },
    {
-     "text": "包括波形,彩虹率和冰普渡",
+     "text": "包括波形,采样率和频谱图",
      "start_time": 26.28,
      "end_time": 28.28,
      "confidence": 0.732666015625,
@@ -111,7 +111,7 @@
      "verification_notes": null
    },
    {
-     "text": "高效加载大规模音频数据集的流逝加载方法。",
+     "text": "高效加载大规模音频数据集的流式加载方法。",
      "start_time": 33.54,
      "end_time": 36.5,
      "confidence": 0.88739013671875,
@@ -138,7 +138,7 @@
      "verification_notes": null
    },
    {
-     "text": "基础的音频相关数",
+     "text": "基础的音频相关术语",
      "start_time": 40.86,
      "end_time": 42.4,
      "confidence": 0.609619140625,
display/display.py CHANGED
@@ -1,32 +1,87 @@
+ """
+ Display module - renders transcription results to the user.
+ """
+
  from rich.console import Console
  from rich.text import Text
  from typing import Literal
+ import logging
+
+ # Logging configuration
+ def setup_logger(name, level=logging.INFO):
+     """Set up a logger."""
+     logger = logging.getLogger(name)
+     # Clear existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to parent loggers, to avoid duplicate log lines
+     logger.propagate = False
+     return logger
 
+ # Create the logger
+ logger = setup_logger("display")
+
+ # Console object
  console = Console()
 
  class OutputRenderer:
+     """
+     Output renderer: shows transcription results to the user,
+     with distinct styles for raw and optimized text.
+     """
+
      def __init__(self):
+         """Initialize the output renderer."""
          self.history = {}  # used to update the optimized content of the same sentence
+         logger.debug("Output renderer initialized")
 
      def display(self, sentence_id: str, text: str, state: Literal["raw", "optimized"]):
+         """
+         Display a transcription result.
+
+         :param sentence_id: sentence ID
+         :param text: text content
+         :param state: "raw" for the original text, "optimized" for the optimized text
+         """
          if state == "raw":
              styled_text = Text(text, style="dim")  # dim grey marks raw output
+             logger.debug(f"Displaying raw text: {sentence_id}")
          elif state == "optimized":
              styled_text = Text(text, style="bold black")  # bold black
+             logger.debug(f"Displaying optimized text: {sentence_id}")
          else:
-             raise ValueError("Unknown output state")
+             logger.error(f"Unknown output state: {state}")
+             raise ValueError(f"Unknown output state: {state}")
 
          # Print the new content (or replace the history)
          if sentence_id in self.history:
              console.print(f"[Update] sentence {sentence_id}:", styled_text)
+             logger.info(f"Updated sentence: {sentence_id}")
          else:
              console.print(f"[Output] sentence {sentence_id}:", styled_text)
+             logger.info(f"Output sentence: {sentence_id}")
 
+         # Record history
          self.history[sentence_id] = text
+         logger.debug(f"Sentence content: {text}")
+
 
  if __name__ == "__main__":
+     # Set the log level to DEBUG for detailed output
+     logger.setLevel(logging.DEBUG)
+
+     # Test code
      renderer = OutputRenderer()
 
+     # Show raw text
      renderer.display("s1", "I think we should start the meeting now.", "raw")
+
      # Simulate optimized backfill
      renderer.display("s1", "I believe it's time to begin the meeting.", "optimized")
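The `setup_logger` helper repeated across these modules clears stale handlers and disables propagation so that re-running module setup does not duplicate log lines. Its idempotence is easy to verify in isolation:

```python
import logging

def setup_logger(name, level=logging.INFO):
    """Set up a logger; safe to call repeatedly for the same name."""
    logger = logging.getLogger(name)
    if logger.handlers:          # drop stale handlers from earlier calls
        logger.handlers.clear()
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    logger.setLevel(level)
    logger.propagate = False     # avoid duplicate lines via the root logger
    return logger

a = setup_logger("demo")
b = setup_logger("demo")         # a second call must not stack a second handler
print(len(b.handlers))           # 1
```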
main.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py ADDED

+ """
+ Pseudo-streaming audio transcription + LLM optimization system: main program.
+
+ This program implements the complete audio processing pipeline:
+ 1. VAD segmentation
+ 2. Whisper transcription
+ 3. Semantic aggregation
+ 4. Immediate display output
+ 5. LLM optimization
+ 6. Translation
+
+ Usage:
+     python main.py [--audio_path AUDIO_PATH] [--use_gpu | --no-use_gpu]
+                    [--enable_translation | --no-enable_translation]
+                    [--enable_optimization | --no-enable_optimization]
+ """
+
+ import os
+ import time
+ import logging
+ import argparse
+ import soundfile as sf
+ from typing import Dict, Union
+
+ # Logging setup
+ def setup_logger(name, level=logging.INFO):
+     """Create a logger with a single stream handler."""
+     logger = logging.getLogger(name)
+     # Clear any existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to the parent logger, to avoid duplicate logs
+     logger.propagate = False
+     return logger
+
+ # Main logger
+ logger = setup_logger("main")
+
+ # Pipeline modules
+ from vad import VoiceActivityDetector
+ from transcribe.transcribe import AudioTranscriber, TranscriptionResult
+ from aggregator.semantic_aggregator import SemanticAggregator
+ from display.display import OutputRenderer
+ from optimizer.dispatcher import OptimizationDispatcher
+ from translator.translator import NLLBTranslator
+
+ class AudioProcessingPipeline:
+     """Complete audio processing pipeline."""
+
+     def __init__(
+         self,
+         audio_path: str,
+         use_gpu: bool = True,
+         enable_translation: bool = True,
+         enable_optimization: bool = True,
+         whisper_model: str = "large",
+         log_level: Union[int, str] = logging.INFO
+     ):
+         """
+         Initialize the processing pipeline.
+
+         :param audio_path: path to the audio file
+         :param use_gpu: whether to use the GPU
+         :param enable_translation: whether to enable translation
+         :param enable_optimization: whether to enable LLM optimization
+         :param whisper_model: Whisper model size (tiny, base, small, medium, large)
+         :param log_level: logging level
+         """
+         # Set the log level
+         if isinstance(log_level, str):
+             log_level = getattr(logging, log_level.upper())
+         logger.setLevel(log_level)
+
+         self.audio_path = audio_path
+         self.use_gpu = use_gpu
+         self.enable_translation = enable_translation
+         self.enable_optimization = enable_optimization
+         self.whisper_model = whisper_model
+
+         # Pick the device
+         self.device = "cuda" if use_gpu and self._is_gpu_available() else "cpu"
+         logger.info(f"Using device: {self.device}")
+         logger.debug(f"Config: whisper_model={whisper_model}, translation={enable_translation}, optimization={enable_optimization}")
+
+         # Initialize the modules
+         self._init_modules()
+
+         # Load the audio data
+         self.audio_data, self.sample_rate = sf.read(audio_path)
+         logger.info(f"Loaded audio: {os.path.basename(audio_path)}, duration: {len(self.audio_data)/self.sample_rate:.2f}s")
+         logger.debug(f"Audio details: sample rate={self.sample_rate}Hz, shape={self.audio_data.shape}")
+
+         # Map sentence IDs to optimization tasks
+         self.optimization_tasks: Dict[str, str] = {}
+
+     def _is_gpu_available(self) -> bool:
+         """Check whether a GPU is available."""
+         try:
+             import torch
+             if torch.cuda.is_available():
+                 logger.debug(f"GPU detected: {torch.cuda.get_device_name(0)}")
+                 return True
+             return False
+         except ImportError:
+             logger.debug("PyTorch not found, falling back to CPU")
+             return False
+
+     def _init_modules(self):
+         """Initialize the processing modules."""
+         # 1. VAD
+         logger.debug("Initializing VAD module...")
+         self.vad = VoiceActivityDetector(save_audio=True, save_json=True)
+
+         # 2. Transcriber
+         logger.debug(f"Initializing Whisper transcription module (model={self.whisper_model}, device={self.device})...")
+         self.transcriber = AudioTranscriber(
+             model=self.whisper_model,
+             device=self.device,
+             compute_type="int8" if self.device == "cuda" else "float32"
+         )
+
+         # 3. Renderer
+         logger.debug("Initializing display module...")
+         self.renderer = OutputRenderer()
+
+         # 4. Optimization dispatcher (if enabled)
+         if self.enable_optimization:
+             logger.debug("Initializing LLM optimization dispatcher...")
+             self.optimizer = OptimizationDispatcher(
+                 max_workers=2,
+                 callback=self._optimization_callback
+             )
+         else:
+             logger.debug("LLM optimization disabled")
+             self.optimizer = None
+
+         # 5. Translator (if enabled)
+         if self.enable_translation:
+             logger.debug("Initializing translation module...")
+             try:
+                 self.translator = NLLBTranslator()
+                 self.translation_enabled = True
+             except Exception as e:
+                 logger.warning(f"Translator initialization failed: {str(e)}")
+                 self.translation_enabled = False
+         else:
+             logger.debug("Translation disabled")
+             self.translation_enabled = False
+             self.translator = None
+
+         # 6. Aggregator
+         logger.debug("Initializing semantic aggregation controller...")
+         self.aggregator = SemanticAggregator(
+             on_display=self._display_callback,
+             on_translate=self._translate_callback,
+             transcriber=self.transcriber,
+             segments_dir="dataset/audio/segments",
+             max_window=5.0,
+             max_segments=5,
+             min_gap=0.8,
+             force_flush_timeout=3.0
+         )
+
+         logger.info("All modules initialized")
+
+     def _display_callback(self, sentence_id: str, text: str, state: str):
+         """Display callback."""
+         self.renderer.display(sentence_id, text, state)
+
+         # If optimization is enabled and this is raw text, submit an optimization task
+         if self.enable_optimization and state == "raw" and self.optimizer:
+             logger.debug(f"Submitting optimization task: {sentence_id}")
+             self.optimizer.submit(sentence_id, text)
+
+     def _translate_callback(self, sentence_id: str, text: str):
+         """Translation callback."""
+         if self.translation_enabled and self.translator:
+             try:
+                 # The translator logs the source text and result itself; just call translate
+                 self.translator.translate(text)
+                 logger.debug(f"Translated sentence: {sentence_id}")
+             except Exception as e:
+                 logger.error(f"Translation failed: {str(e)}")
+
+     def _optimization_callback(self, sentence_id: str, original_text: str, optimized_text: str):
+         """Optimization callback."""
+         logger.debug(f"Received optimization result: {sentence_id}")
+         # Update the display
+         self.renderer.display(sentence_id, optimized_text, "optimized")
+         # If translation is enabled, translate the optimized text
+         if self.translation_enabled:
+             logger.debug(f"Translating optimized text: {sentence_id}")
+             self._translate_callback(sentence_id, optimized_text)
+
+     def process(self):
+         """Process the audio file."""
+         logger.info("Starting audio processing...")
+
+         # 1. VAD segmentation
+         logger.debug("Running VAD segmentation...")
+         segments = self.vad.detect_voice_segments(self.audio_data, self.sample_rate)
+         logger.info(f"VAD segmentation done: {len(segments)} segments")
+
+         # 2. Transcribe each segment
+         for i, (start, end) in enumerate(segments):
+             logger.debug(f"Transcribing segment {i+1}/{len(segments)}: {start:.2f}s -> {end:.2f}s")
+
+             # Slice the segment's audio out of the recording by sample index
+             segment_audio = self.audio_data[int(start * self.sample_rate):int(end * self.sample_rate)]
+
+             # Transcribe the segment
+             results = self.transcriber.transcribe_segment(segment_audio, start_time=start)
+
+             # Attach the 1-based segment index and feed the aggregator
+             for result in results:
+                 result.segment_index = i + 1
+                 logger.debug(f"Adding transcription result: {result.text}")
+                 self.aggregator.add_segment(result)
+
+             # Simulate processing latency
+             time.sleep(0.1)
+
+         # 3. Force-flush the buffer at the end
+         logger.debug("Force-flushing the buffer...")
+         self.aggregator.flush(force=True)
+
+         # 4. Wait for all optimization tasks to finish
+         if self.enable_optimization and self.optimizer:
+             logger.debug("Waiting for all optimization tasks to finish...")
+             self.optimizer.wait_until_done()
+
+         logger.info("Audio processing finished")
+
+ def parse_args():
+     """Parse command-line arguments."""
+     parser = argparse.ArgumentParser(description="Pseudo-streaming audio transcription + LLM optimization system")
+     parser.add_argument("--audio_path", type=str, default="dataset/audio/test1.wav",
+                         help="path to the audio file")
+     # Note: `action="store_true"` combined with `default=True` would make these
+     # flags impossible to turn off; BooleanOptionalAction (Python >= 3.9) also
+     # generates the matching --no-* switches.
+     parser.add_argument("--use_gpu", action=argparse.BooleanOptionalAction, default=True,
+                         help="use the GPU")
+     parser.add_argument("--enable_translation", action=argparse.BooleanOptionalAction, default=True,
+                         help="enable translation")
+     parser.add_argument("--enable_optimization", action=argparse.BooleanOptionalAction, default=True,
+                         help="enable LLM optimization")
+     parser.add_argument("--whisper_model", type=str, default="small",
+                         choices=["tiny", "base", "small", "medium", "large"],
+                         help="Whisper model size")
+     parser.add_argument("--log_level", type=str, default="INFO",
+                         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+                         help="logging level")
+     return parser.parse_args()
+
+ def main():
+     """Entry point."""
+     # Parse command-line arguments
+     args = parse_args()
+
+     # Resolve the log level
+     log_level = getattr(logging, args.log_level)
+
+     # Apply the level to all module loggers
+     for module in ["main", "vad", "transcribe", "aggregator", "display", "optimizer", "translator"]:
+         setup_logger(module, log_level)
+
+     # Check the OpenAI API key (used for sentence-completeness judging and optimization)
+     if not os.getenv("OPENAI_API_KEY") and args.enable_optimization:
+         logger.warning("OPENAI_API_KEY is not set; sentence-completeness judging will use the fallback method")
+
+     # Check that the audio file exists
+     if not os.path.exists(args.audio_path):
+         logger.error(f"Audio file not found: {args.audio_path}")
+         return
+
+     # Build and run the pipeline
+     pipeline = AudioProcessingPipeline(
+         audio_path=args.audio_path,
+         use_gpu=args.use_gpu,
+         enable_translation=args.enable_translation,
+         enable_optimization=args.enable_optimization,
+         whisper_model=args.whisper_model,
+         log_level=log_level
+     )
+
+     # Process the audio
+     pipeline.process()
+
+ if __name__ == "__main__":
+     main()
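The `setup_logger` helper above clears existing handlers and disables propagation, so re-running setup for the same logger name cannot double the output. A minimal standalone check of that idempotence, mirroring the helper:

```python
import logging

def setup_logger(name, level=logging.INFO):
    """Mirror of the pipeline's helper: one handler, no propagation."""
    logger = logging.getLogger(name)
    if logger.handlers:
        logger.handlers.clear()  # avoid duplicate handlers on re-init
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)
    logger.setLevel(level)
    logger.propagate = False  # avoid duplicate lines via the root logger
    return logger

# Re-initializing must not accumulate handlers
first = setup_logger("demo")
second = setup_logger("demo")
```

Since this helper is copied verbatim into each module of the commit, hoisting it into one shared module would keep the copies from drifting.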
optimizer/dispatcher.py CHANGED
@@ -1,46 +1,127 @@
- # optimizer/dispatcher.py
  import asyncio
  from concurrent.futures import ThreadPoolExecutor
- from optimizer.llm_runner import TinyLLaMARunner
  from optimizer.optimize_task import OptimizeTask

- class OptimizationDispatcher:
-     def __init__(self, max_workers: int = 1):
-         self.queue = asyncio.Queue()
-         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-         self.model_runner = TinyLLaMARunner()
-
-     def submit(self, sentence_id: str, text: str, callback):
-         task = OptimizeTask(sentence_id, text, callback)
-         self.queue.put_nowait(task)

-     async def start(self):
-         print("[Dispatcher] Starting the optimization dispatcher...")
-         while True:
-             task: OptimizeTask = await self.queue.get()
-             asyncio.create_task(self._handle(task))

-     async def _handle(self, task: OptimizeTask):
-         await asyncio.get_event_loop().run_in_executor(
-             self.executor,
-             task.run,
-             self.model_runner
-         )

  if __name__ == "__main__":
-     import time
-
-     def test_callback(sid, result):
-         print(f"[Backfill] {sid}: {result}")
-
-     async def main():
-         dispatcher = OptimizationDispatcher()
-         asyncio.create_task(dispatcher.start())
-
-         dispatcher.submit("s001", "we maybe start tomorrow okay", test_callback)
-         dispatcher.submit("s002", "they need eat fast meeting now", test_callback)
-
-         await asyncio.sleep(5)
-
-     asyncio.run(main())

+ """
+ Optimization dispatcher - manages the LLM optimization task queue.
+ """
+
  import asyncio
+ import logging
  from concurrent.futures import ThreadPoolExecutor
+ from typing import Callable, Optional
+ from optimizer.llm_api_runner import ChatGPTRunner
  from optimizer.optimize_task import OptimizeTask

+ # Logging setup
+ def setup_logger(name, level=logging.INFO):
+     """Create a logger with a single stream handler."""
+     logger = logging.getLogger(name)
+     # Clear any existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to the parent logger, to avoid duplicate logs
+     logger.propagate = False
+     return logger
+
+ # Module logger
+ logger = setup_logger("optimizer")

+ class OptimizationDispatcher:
+     """
+     Optimization dispatcher that manages the LLM optimization task queue.
+     Supports processing multiple optimization tasks concurrently.
+     """
+
+     def __init__(self, max_workers: int = 2, callback: Optional[Callable] = None):
+         """
+         Initialize the dispatcher.
+
+         :param max_workers: maximum number of worker threads
+         :param callback: callback invoked when an optimization finishes
+         """
+         self.tasks = {}  # maps sentence IDs to tasks
+         self.max_workers = max_workers
+         self.executor = ThreadPoolExecutor(max_workers=max_workers)
+         self.model_runner = ChatGPTRunner()
+         self.callback = callback
+         logger.debug(f"Dispatcher initialized, max workers: {max_workers}")
+
+     def submit(self, sentence_id: str, text: str, callback: Optional[Callable] = None):
+         """
+         Submit an optimization task.
+
+         :param sentence_id: sentence ID
+         :param text: text to optimize
+         :param callback: completion callback; falls back to the default callback if None
+         """
+         task_callback = callback or self.callback
+         task = OptimizeTask(sentence_id, text, task_callback)
+         self.tasks[sentence_id] = task
+         logger.debug(f"Submitting optimization task: {sentence_id}")
+
+         # Run the task on the thread pool
+         self.executor.submit(self._process_task, task)
+         logger.debug(f"Task submitted to the thread pool: {sentence_id}")
+
+     def _process_task(self, task: OptimizeTask):
+         """
+         Process an optimization task.
+
+         :param task: the task to process
+         """
+         try:
+             logger.debug(f"Processing task: {task.sentence_id}")
+             # Optimize the text with the model runner
+             optimized_text = self.model_runner.optimize(task.text)
+             logger.debug(f"Task processed: {task.sentence_id}")
+
+             # Invoke the callback
+             if task.callback:
+                 task.callback(task.sentence_id, task.text, optimized_text)
+                 logger.debug(f"Callback invoked: {task.sentence_id}")
+
+             # Remove the task from the registry
+             if task.sentence_id in self.tasks:
+                 del self.tasks[task.sentence_id]
+
+             logger.info(f"Optimization task finished: {task.sentence_id}")
+         except Exception as e:
+             logger.error(f"Task failed: {task.sentence_id}, error: {str(e)}")
+
+     def wait_until_done(self):
+         """
+         Block until all submitted tasks have finished.
+
+         Note: ThreadPoolExecutor.shutdown() accepts no timeout argument, so this
+         waits indefinitely, then recreates the pool so the dispatcher remains usable.
+         """
+         logger.debug(f"Waiting for all tasks, pending: {len(self.tasks)}")
+         self.executor.shutdown(wait=True)
+         # Recreate the thread pool for subsequent submissions
+         self.executor = ThreadPoolExecutor(max_workers=self.max_workers)
+         logger.debug("All tasks finished")
+         return True

  if __name__ == "__main__":
+     # Use DEBUG to see the details
+     logger.setLevel(logging.DEBUG)
+
+     # Test callback
+     def test_callback(sentence_id, original_text, optimized_text):
+         logger.info(f"[Backfill] {sentence_id}: {optimized_text}")
+
+     # Create the dispatcher
+     dispatcher = OptimizationDispatcher(callback=test_callback)
+
+     # Submit test tasks
+     dispatcher.submit("s001", "we maybe start tomorrow okay")
+     dispatcher.submit("s002", "they need eat fast meeting now")
+
+     # Wait for the tasks to finish
+     dispatcher.wait_until_done()
+
+     logger.info("Test finished")
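`wait_until_done` above tears down the pool because `ThreadPoolExecutor.shutdown()` has no timeout parameter. If a bounded wait is needed, keeping the submitted futures and waiting on them with `concurrent.futures.wait` gives a timeout without destroying the pool; a sketch (the class name is illustrative, not the project's dispatcher):

```python
from concurrent.futures import ThreadPoolExecutor, wait

class TimedDispatcher:
    """Sketch: retain submitted futures so completion can be awaited with a timeout."""

    def __init__(self, max_workers: int = 2):
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.futures = []

    def submit(self, fn, *args):
        fut = self.executor.submit(fn, *args)
        self.futures.append(fut)
        return fut

    def wait_until_done(self, timeout=None) -> bool:
        # wait() supports a timeout, unlike executor.shutdown(); the pool survives
        done, not_done = wait(self.futures, timeout=timeout)
        self.futures = list(not_done)  # keep anything still running
        return not not_done
```

This also avoids recreating the executor and reading its private `_max_workers` attribute.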
optimizer/llm_api_runner.py CHANGED
@@ -1,15 +1,61 @@
- # optimizer/llm_api_runner.py
  from openai import OpenAI
  import os

- MODEL_NAME = "gpt-3.5-turbo"  # can be swapped for another model

  class ChatGPTRunner:
-     def __init__(self, model="gpt-3.5-turbo"):
          self.model = model
-         self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

      def build_prompt(self, text: str) -> str:
          return (
              "示例:\n"
              "原句:你门现在就得开会了,别迟到了。\n"
@@ -20,26 +66,60 @@ class ChatGPTRunner:
              "修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
              "原句:系统将进入留言模式,请耐行等待。\n"
              "修改:系统将进入留言模式,请耐心等待。\n"
              f"原句:{text}\n"
              f"修改:"
          )

      def optimize(self, text: str, max_tokens: int = 256) -> str:
          prompt = self.build_prompt(text)
-         response = self.client.chat.completions.create(
-             model=self.model,
-             messages=[
-                 {"role": "system", "content": "你是用于优化语音识别的转写结果的校对助手。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。"},
-                 {"role": "user", "content": prompt}
-             ],
-             temperature=0.4,
-             max_tokens=max_tokens,
-         )
-         return response.choices[0].message.content.strip()

  if __name__ == "__main__":
      runner = ChatGPTRunner(MODEL_NAME)
      test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"
      result = runner.optimize(test_input)
-     print("Before:", test_input)
-     print("After:", result)

+ """
+ ChatGPT optimizer - refines transcription results via the OpenAI API.
+ """
+
  from openai import OpenAI
  import os
+ import logging
+ import time
+
+ # Logging setup
+ def setup_logger(name, level=logging.INFO):
+     """Create a logger with a single stream handler."""
+     logger = logging.getLogger(name)
+     # Clear any existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to the parent logger, to avoid duplicate logs
+     logger.propagate = False
+     return logger
+
+ # Module logger
+ logger = setup_logger("optimizer.api")

+ # Default model
+ MODEL_NAME = "gpt-3.5-turbo"

  class ChatGPTRunner:
+     """
+     ChatGPT-based optimizer for transcription results.
+     """
+
+     def __init__(self, model: str = MODEL_NAME):
+         """
+         Initialize the optimizer.
+
+         :param model: name of the model to use
+         """
          self.model = model
+         api_key = os.getenv("OPENAI_API_KEY")
+         if not api_key:
+             logger.warning("OPENAI_API_KEY is not set")
+         self.client = OpenAI(api_key=api_key)
+         logger.debug(f"ChatGPT optimizer initialized, model: {model}")

      def build_prompt(self, text: str) -> str:
+         """
+         Build the few-shot correction prompt (the examples are Chinese ASR
+         corrections and are deliberately kept in Chinese).
+
+         :param text: text to optimize
+         :return: the assembled prompt
+         """
          return (
              "示例:\n"
              "原句:你门现在就得开会了,别迟到了。\n"

              "修改:本章节将为你介绍音频数据的基本概念,包括波形、采样、频谱、图像。\n"
              "原句:系统将进入留言模式,请耐行等待。\n"
              "修改:系统将进入留言模式,请耐心等待。\n"
+             "原句:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。\n"
+             "修改:你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法。\n"
              f"原句:{text}\n"
              f"修改:"
          )

      def optimize(self, text: str, max_tokens: int = 256) -> str:
+         """
+         Optimize a piece of text.
+
+         :param text: text to optimize
+         :param max_tokens: maximum number of generated tokens
+         :return: the optimized text
+         """
+         logger.debug(f"Optimizing text: {text}")
+         start_time = time.time()
+
+         # Build the prompt
          prompt = self.build_prompt(text)
+
+         try:
+             # Call the API
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[
+                     {"role": "system", "content": "你是用于优化语音识别的转写结果的校对助手。请保留原始句子的结构,仅修正错别字、语义不通或专业术语使用错误的部分。不要增加、删减或合并句子,务必保留原文的信息表达,仅对用词错误做最小修改。"},
+                     {"role": "user", "content": prompt}
+                 ],
+                 temperature=0.4,
+                 max_tokens=max_tokens,
+             )
+
+             # Extract the result
+             result = response.choices[0].message.content.strip()
+
+             # Record the elapsed time
+             elapsed_time = time.time() - start_time
+             logger.debug(f"Optimization finished in {elapsed_time:.2f}s")
+             logger.info(f"Optimized result: {result}")
+
+             return result
+         except Exception as e:
+             logger.error(f"Optimization failed: {str(e)}")
+             # Fall back to the original text on error
+             return text

  if __name__ == "__main__":
+     # Use DEBUG to see the details
+     logger.setLevel(logging.DEBUG)
+
+     # Test the optimizer
      runner = ChatGPTRunner(MODEL_NAME)
      test_input = "你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流逝加载方法。"
+
+     logger.info(f"Before: {test_input}")
      result = runner.optimize(test_input)
+     logger.info(f"After: {result}")
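`optimize` above swallows every API error and returns the original text. For transient failures (rate limits, timeouts) a small retry-with-backoff wrapper around the API call is a common middle ground before giving up; a generic sketch (the helper name is ours, not part of this commit):

```python
import time

def with_retries(fn, attempts: int = 3, base_delay: float = 1.0):
    """Call fn(), retrying with exponential backoff; re-raise after the last attempt."""
    for i in range(attempts):
        try:
            return fn()
        except Exception:
            if i == attempts - 1:
                raise  # out of attempts: surface the error to the caller
            time.sleep(base_delay * (2 ** i))  # 1s, 2s, 4s, ...

# Example: a callable that fails twice, then succeeds
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"
```

In the runner this would wrap only the `self.client.chat.completions.create(...)` call, with the final fallback to the original text kept as-is.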
optimizer/optimize_task.py CHANGED
@@ -1,23 +1,74 @@
- # optimizer/optimize_task.py
- from typing import Callable
- from optimizer.llm_runner import TinyLLaMARunner

  class OptimizeTask:
-     def __init__(self, sentence_id: str, text: str, callback: Callable[[str, str], None]):
          self.sentence_id = sentence_id
          self.text = text
          self.callback = callback
-
-     def run(self, model_runner: TinyLLaMARunner):
-         optimized_text = model_runner.optimize(self.text)
-         self.callback(self.sentence_id, optimized_text)

  if __name__ == "__main__":
-     def fake_callback(sid, text):
-         print(f"[Callback] optimization result ({sid}) -> {text}")
-
      task = OptimizeTask("s001", "they go home maybe tomorrow", fake_callback)
-     from optimizer.llm_runner import TinyLLaMARunner
-     runner = TinyLLaMARunner()
-     task.run(runner)

+ """
+ Optimization task - represents one LLM optimization job.
+ """
+
+ import logging
+ from typing import Callable, Optional
+
+ # Logging setup
+ def setup_logger(name, level=logging.INFO):
+     """Create a logger with a single stream handler."""
+     logger = logging.getLogger(name)
+     # Clear any existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to the parent logger, to avoid duplicate logs
+     logger.propagate = False
+     return logger
+
+ # Module logger
+ logger = setup_logger("optimizer.task")

  class OptimizeTask:
+     """
+     An optimization task: a piece of text awaiting LLM optimization.
+     """
+
+     def __init__(self, sentence_id: str, text: str, callback: Optional[Callable[[str, str, str], None]] = None):
+         """
+         Initialize the task.
+
+         :param sentence_id: sentence ID
+         :param text: text to optimize
+         :param callback: completion callback taking (sentence_id, original_text, optimized_text)
+         """
          self.sentence_id = sentence_id
          self.text = text
          self.callback = callback
+         logger.debug(f"Created optimization task: {sentence_id}")
+
+     def __str__(self):
+         """String representation."""
+         return f"OptimizeTask(id={self.sentence_id}, text={self.text[:20]}...)"

  if __name__ == "__main__":
+     # Use DEBUG to see the details
+     logger.setLevel(logging.DEBUG)
+
+     # Test callback
+     def fake_callback(sid, original_text, optimized_text):
+         logger.info(f"[Callback] optimization result ({sid}) -> {optimized_text}")
+
+     # Create a task
      task = OptimizeTask("s001", "they go home maybe tomorrow", fake_callback)
+
+     # Create the model runner
+     from optimizer.llm_api_runner import ChatGPTRunner
+     runner = ChatGPTRunner()
+
+     # Optimize the text
+     optimized_text = runner.optimize(task.text)
+
+     # Invoke the callback
+     if task.callback:
+         task.callback(task.sentence_id, task.text, optimized_text)
+
+     logger.info("Test finished")
transcribe/transcribe.py CHANGED
@@ -21,11 +21,25 @@ class TranscriptionResult:
      segment_index: Optional[int] = None  # segment index field

  # Logging setup
- logger = logging.getLogger("transcribe")
- handler = logging.StreamHandler()
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- handler.setFormatter(formatter)
- logger.addHandler(handler)

  class AudioTranscriber:
      def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
@@ -44,12 +58,12 @@ class AudioTranscriber:
          log_level = getattr(logging, log_level.upper())
          logger.setLevel(log_level)

-         logger.debug("📥 Loading Whisper model...")

          from faster_whisper import WhisperModel
          self.model = WhisperModel(model, device=device, compute_type=compute_type)

-         logger.debug("📥 Loading Whisper model successfully!!")

      def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List[TranscriptionResult]:
          """
@@ -64,7 +78,6 @@
          """
          start_process_time = time.time()

-         logger.debug("Model transcribe...")
          logger.debug(f"Transcribing audio segment, length: {len(audio_data)} samples ({len(audio_data)/16000:.2f}s)")

          try:
@@ -74,9 +87,9 @@

              segments = list(segments_generator)

-             logger.debug(f"Model transcribe successfully! Segments count: {len(segments)}")
              if len(segments) > 0:
-                 logger.debug(f"First segment: {segments[0]}")

              results = []
              for seg in segments:
@@ -141,6 +154,7 @@
          with open(output_path, 'w', encoding='utf-8') as f:
              json.dump(data, f, ensure_ascii=False, indent=2)

          return output_path


@@ -149,12 +163,11 @@
      audio_path = "dataset/audio/test1.wav"  # replace with the actual audio file path
      import soundfile as sf

-     # Set the log level: DEBUG, INFO, WARNING, ERROR, CRITICAL
-     # It can be given as a string or a constant
-     processor = AudioTranscriber(log_level="DEBUG")  # or log_level=logging.INFO

-     # The logger level can also be set directly
-     # logger.setLevel(logging.DEBUG)  # show all detailed logs

      try:
          audio_data, sample_rate = sf.read(audio_path)

      segment_index: Optional[int] = None  # segment index field

  # Logging setup
+ def setup_logger(name, level=logging.INFO):
+     """Create a logger with a single stream handler."""
+     logger = logging.getLogger(name)
+     # Clear any existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to the parent logger, to avoid duplicate logs
+     logger.propagate = False
+     return logger
+
+ # Module logger
+ logger = setup_logger("transcribe")

  class AudioTranscriber:
      def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",

          log_level = getattr(logging, log_level.upper())
          logger.setLevel(log_level)

+         logger.debug(f"Initializing transcriber: model={model}, device={device}, compute_type={compute_type}")

          from faster_whisper import WhisperModel
          self.model = WhisperModel(model, device=device, compute_type=compute_type)

+         logger.debug("Whisper model loaded")

      def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List[TranscriptionResult]:
          """

          """
          start_process_time = time.time()

          logger.debug(f"Transcribing audio segment, length: {len(audio_data)} samples ({len(audio_data)/16000:.2f}s)")

          try:

              segments = list(segments_generator)

+             logger.debug(f"Transcription succeeded, segment count: {len(segments)}")
              if len(segments) > 0:
+                 logger.debug(f"First segment: {segments[0]}")

              results = []
              for seg in segments:

          with open(output_path, 'w', encoding='utf-8') as f:
              json.dump(data, f, ensure_ascii=False, indent=2)

+         logger.info(f"Transcription result saved to: {output_path}")
          return output_path


      audio_path = "dataset/audio/test1.wav"  # replace with the actual audio file path
      import soundfile as sf

+     # Use DEBUG to see the details
+     logger.setLevel(logging.DEBUG)

+     # Initialize the transcriber
+     processor = AudioTranscriber(log_level="DEBUG")

      try:
          audio_data, sample_rate = sf.read(audio_path)
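`transcribe_segment` receives audio that main.py slices out of the full recording by converting VAD timestamps to sample indices (`audio[int(start*sr):int(end*sr)]`). A tiny sketch of that conversion, isolated for clarity:

```python
import numpy as np

def slice_segment(audio: np.ndarray, sr: int, start_s: float, end_s: float) -> np.ndarray:
    """Extract the half-open span [start_s, end_s) from a 1-D audio array."""
    return audio[int(start_s * sr):int(end_s * sr)]

# Three seconds of dummy audio at 16 kHz; take the second second
audio = np.arange(16000 * 3)
seg = slice_segment(audio, 16000, 1.0, 2.0)
```

Truncating with `int()` rounds toward zero, so boundaries can be off by up to one sample; that is harmless at 16 kHz.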
translator/translator.py CHANGED
@@ -1,20 +1,56 @@
- # translator_nllb.py

  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  from langdetect import detect
  import torch
  import time
-

  class NLLBTranslator:
      def __init__(self, model_name="facebook/nllb-200-distilled-600M", default_target="eng_Latn"):
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         print(f"[⚙️ Loading model] device: {self.device}")
          if self.device.type == "cuda":
-             print(f"[GPU] current device: {torch.cuda.get_device_name(0)}")
              total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
-             print(f"[GPU] total memory: {total_mem:.1f} GB")

          self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
          self.model = AutoModelForSeq2SeqLM.from_pretrained(
              model_name,
@@ -22,51 +58,92 @@ class NLLBTranslator:
          ).to(self.device)

          self.default_target = default_target

      def detect_lang_code(self, text: str) -> str:
          try:
              lang = detect(text)
          except Exception:
-             print("⚠️ Detection failed, defaulting to zh")
              lang = "zh-cn"

          lang_map = {
              "zh-cn": "zho_Hans", "zh": "zho_Hans", "en": "eng_Latn", "fr": "fra_Latn",
              "de": "deu_Latn", "ja": "jpn_Jpan", "ko": "kor_Hang", "ar": "arb_Arab"
          }
          lang_code = lang_map.get(lang.lower(), "eng_Latn")
-         print(f"[🔍 Language detection] Detected `{lang}`, mapped to `{lang_code}`")
          return lang_code

      def translate(self, text: str, target_lang_code: str = None) -> str:
-         print("\n🌐 [Translation task started]")
-         print(f"Source: {text}")
-
          src_lang = self.detect_lang_code(text)
          tgt_lang = target_lang_code or self.default_target

          self.tokenizer.src_lang = src_lang
          inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
          inputs["forced_bos_token_id"] = self.tokenizer.convert_tokens_to_ids(tgt_lang)

          start = time.time()
          with torch.no_grad():
              output = self.model.generate(**inputs, max_new_tokens=80)
          result = self.tokenizer.decode(output[0], skip_special_tokens=True)
-         print(f"[✅ Done] {src_lang} → {tgt_lang}, took {time.time() - start:.2f}s")
          return result


  if __name__ == "__main__":
      translator = NLLBTranslator()

      zh_text = "你会学习到如何使用音频数据集"
-     print("\n==== Chinese → English ====")
-     print("Result:", translator.translate(zh_text, target_lang_code="eng_Latn"))

      en_text = "This audio processing pipeline is fast and accurate."
-     print("\n==== English → French ====")
-     print("Result:", translator.translate(en_text, target_lang_code="fra_Latn"))
-
-     print("\n==== English → Arabic ====")
-     print("Result:", translator.translate(en_text, target_lang_code="arb_Arab"))

+ """
+ Translation module - multilingual translation with the NLLB model.
+ """

  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  from langdetect import detect
  import torch
  import time
+ import logging
+
+ # Logging setup
+ def setup_logger(name, level=logging.INFO):
+     """Create a logger with a single stream handler."""
+     logger = logging.getLogger(name)
+     # Clear any existing handlers to avoid duplicates
+     if logger.handlers:
+         logger.handlers.clear()
+
+     # Attach a fresh handler
+     handler = logging.StreamHandler()
+     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     logger.setLevel(level)
+     # Do not propagate to the parent logger, to avoid duplicate logs
+     logger.propagate = False
+     return logger
+
+ # Module logger
+ logger = setup_logger("translator")

  class NLLBTranslator:
+     """
+     NLLB translator built on Facebook's NLLB model.
+     """
+
      def __init__(self, model_name="facebook/nllb-200-distilled-600M", default_target="eng_Latn"):
+         """
+         Initialize the NLLB translator.
+
+         :param model_name: model name
+         :param default_target: default target language code
+         """
          self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logger.debug(f"Using device: {self.device}")
+
          if self.device.type == "cuda":
+             logger.debug(f"GPU device: {torch.cuda.get_device_name(0)}")
              total_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
+             logger.debug(f"GPU memory: {total_mem:.1f} GB")

+         # Load the model and tokenizer
+         logger.debug(f"Loading model: {model_name}")
          self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
          self.model = AutoModelForSeq2SeqLM.from_pretrained(
              model_name,

          ).to(self.device)

          self.default_target = default_target
+         logger.debug(f"Translator initialized, default target language: {default_target}")

      def detect_lang_code(self, text: str) -> str:
+         """
+         Detect the text's language and return the NLLB language code.
+
+         :param text: text to examine
+         :return: NLLB language code
+         """
          try:
              lang = detect(text)
+             logger.debug(f"Detected language: {lang}")
          except Exception:
+             logger.debug("Language detection failed, defaulting to Chinese (zh)")
              lang = "zh-cn"

+         # Language-code mapping
          lang_map = {
              "zh-cn": "zho_Hans", "zh": "zho_Hans", "en": "eng_Latn", "fr": "fra_Latn",
              "de": "deu_Latn", "ja": "jpn_Jpan", "ko": "kor_Hang", "ar": "arb_Arab"
          }
+
          lang_code = lang_map.get(lang.lower(), "eng_Latn")
+         logger.debug(f"Mapped language code: {lang} -> {lang_code}")
          return lang_code

      def translate(self, text: str, target_lang_code: str = None) -> str:
+         """
+         Translate text into the target language.
+
+         :param text: text to translate
+         :param target_lang_code: target language code; uses the default if None
+         :return: the translated text
+         """
+         logger.debug("Starting translation")
+
+         # Log the source text (INFO level)
+         logger.info(f"[Source] {text}")
+
+         # Detect the source language
          src_lang = self.detect_lang_code(text)
          tgt_lang = target_lang_code or self.default_target

+         # Prepare the inputs
          self.tokenizer.src_lang = src_lang
          inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
          inputs["forced_bos_token_id"] = self.tokenizer.convert_tokens_to_ids(tgt_lang)

+         # Run the translation
          start = time.time()
          with torch.no_grad():
              output = self.model.generate(**inputs, max_new_tokens=80)
+
+         # Decode the result
          result = self.tokenizer.decode(output[0], skip_special_tokens=True)
+
+         # Log the elapsed time
+         duration = time.time() - start
+         logger.debug(f"Translation done: {src_lang} -> {tgt_lang}, took {duration:.2f}s")
+
+         # Log the translation result (INFO level)
+         logger.info(f"[Result] {result}")
+
          return result


  if __name__ == "__main__":
+     # Use DEBUG to see the details
+     logger.setLevel(logging.DEBUG)
+
+     # Create the translator
      translator = NLLBTranslator()

+     # Chinese → English
      zh_text = "你会学习到如何使用音频数据集"
+     logger.info("\n==== Chinese → English ====")
+     result = translator.translate(zh_text, target_lang_code="eng_Latn")
+     logger.info(f"Done: {result}")

+     # English → French
      en_text = "This audio processing pipeline is fast and accurate."
+     logger.info("\n==== English → French ====")
+     result = translator.translate(en_text, target_lang_code="fra_Latn")
+     logger.info(f"Done: {result}")
+
+     # English → Arabic
+     logger.info("\n==== English → Arabic ====")
+     result = translator.translate(en_text, target_lang_code="arb_Arab")
+     logger.info(f"Done: {result}")
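The langdetect → NLLB mapping in `detect_lang_code` silently falls back to `eng_Latn` for any language outside its eight-entry table, so unsupported source languages are translated as if they were English. Isolating the mapping makes that fallback easy to see and test (a sketch of the existing logic, not new behavior):

```python
LANG_MAP = {
    "zh-cn": "zho_Hans", "zh": "zho_Hans", "en": "eng_Latn", "fr": "fra_Latn",
    "de": "deu_Latn", "ja": "jpn_Jpan", "ko": "kor_Hang", "ar": "arb_Arab",
}

def to_nllb_code(lang: str) -> str:
    """Map a langdetect tag to an NLLB language code, defaulting to English."""
    return LANG_MAP.get(lang.lower(), "eng_Latn")
```

Logging a warning (rather than a debug line) when the fallback fires would make mis-detected inputs easier to spot in production.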
vad/__init__.py CHANGED
@@ -1,3 +1,36 @@
  from .vad import AudioVad, AudioSegment

- __all__ = ['AudioVad', 'AudioSegment']

  from .vad import AudioVad, AudioSegment
+ import numpy as np
+ from typing import List, Tuple

+ class VoiceActivityDetector:
+     """
+     VAD detector for locating speech segments in audio.
+     A thin wrapper around the AudioVad implementation.
+     """
+
+     def __init__(self, save_audio=True, save_json=True):
+         """
+         Initialize the VAD detector.
+
+         :param save_audio: whether to save the segmented audio
+         :param save_json: whether to save the JSON metadata
+         """
+         self.vad = AudioVad(
+             save_audio=save_audio,
+             save_json=save_json,
+             output_dir="dataset/audio/segments",
+             json_dir="dataset/audio/metadata"
+         )
+
+     def detect_voice_segments(self, audio_data: np.ndarray, sample_rate: int) -> List[Tuple[float, float]]:
+         """
+         Detect the speech segments in the audio.
+
+         :param audio_data: audio data
+         :param sample_rate: sample rate
+         :return: list of (start_time, end_time) tuples
+         """
+         segments = self.vad.process_audio_data(audio_data, sample_rate)
+         return [(segment.start_time, segment.end_time) for segment in segments]
+
+ __all__ = ['AudioVad', 'AudioSegment', 'VoiceActivityDetector']
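The README describes the VAD as judging segment boundaries from energy and silence cues, but `AudioVad` itself is not part of this commit. As a purely illustrative toy (not the project's implementation), an RMS-threshold segmenter that returns the same `(start, end)` tuples the wrapper above exposes:

```python
import numpy as np

def energy_segments(audio: np.ndarray, sr: int, frame_ms: int = 30, threshold: float = 0.01):
    """Toy energy VAD: return (start_s, end_s) spans whose frame RMS exceeds threshold."""
    frame = int(sr * frame_ms / 1000)
    n = len(audio) // frame
    # Per-frame RMS energy decision
    active = [float(np.sqrt(np.mean(audio[i * frame:(i + 1) * frame] ** 2))) > threshold
              for i in range(n)]
    spans, start = [], None
    for i, a in enumerate(active):
        if a and start is None:
            start = i                      # speech onset
        elif not a and start is not None:
            spans.append((start * frame / sr, i * frame / sr))  # speech offset
            start = None
    if start is not None:                  # speech runs to the end of the audio
        spans.append((start * frame / sr, n * frame / sr))
    return spans

# One second of 0.5-amplitude "speech" between two seconds of silence
sr = 1000
audio = np.concatenate([np.zeros(sr), 0.5 * np.ones(sr), np.zeros(sr)])
spans = energy_segments(audio, sr)
```

A production VAD would additionally smooth over short silences and apply hangover frames, which this sketch omits.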