siyuwang541 committed on
Commit 95bd630 · verified · 1 Parent(s): bb95d10
Files changed (13)
  1. .gitattributes +2 -0
  2. 1.png +0 -0
  3. README.md +15 -15
  4. __init__.py +0 -0
  5. app.py +1330 -722
  6. app_pro.py +840 -0
  7. audio_127.0.0.1.wav +3 -0
  8. image_127.0.0.1.jpg +0 -0
  9. requirements.txt +8 -4
  10. se_app.py +232 -0
  11. temp_audio.wav +3 -0
  12. todogen_LLM_config.yaml +11 -1
  13. tools.py +828 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ audio_127.0.0.1.wav filter=lfs diff=lfs merge=lfs -text
37
+ temp_audio.wav filter=lfs diff=lfs merge=lfs -text
1.png ADDED
README.md CHANGED
@@ -1,16 +1,16 @@
1
- ---
2
- title: ToDoAgent
3
- emoji: 💬
4
- colorFrom: yellow
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.32.0
8
- app_file: app.py
9
- pinned: false
10
- license: bsd
11
- short_description: AI Agent filters, creates to-do list and reminds smartly
12
- tags: ['agent-demo-track']
13
- demo: https://youtu.be/S-wh3Psx15M?si=Wiq7EzmE3dmBvLKQ
14
- ---
15
-
16
  An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
1
+ ---
2
+ title: ToDoAgent
3
+ emoji: 💬
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.32.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: bsd
11
+ short_description: AI Agent filters, creates to-do list and reminds smartly
12
+ tags: ['agent-demo-track']
13
+ demo: https://youtu.be/S-wh3Psx15M?si=Wiq7EzmE3dmBvLKQ
14
+ ---
15
+
16
  An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
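
For readers unfamiliar with that stock template sentence, here is a minimal sketch of the Gradio + `huggingface_hub` Inference API pattern it refers to. The model id and generation defaults below are placeholders, and this Space's actual `app.py` (diffed below) is considerably more elaborate than this sketch:

```python
# Minimal sketch of the pattern the README describes: a Gradio chat UI that
# streams replies from the Hugging Face Inference API via huggingface_hub.
# The model id below is a placeholder, not the model this Space actually uses.
import gradio as gr
from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")  # placeholder model

def respond(message, history, system_message, max_tokens, temperature, top_p):
    messages = [{"role": "system", "content": system_message}]
    # With the default ChatInterface settings, history arrives as (user, assistant) pairs.
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    response = ""
    # chat_completion streams OpenAI-style chunks; accumulate the delta tokens.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        response += chunk.choices[0].delta.content or ""
        yield response

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```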
__init__.py ADDED
File without changes
app.py CHANGED
@@ -1,722 +1,1330 @@
1
- import gradio as gr
2
- import json
3
- from pathlib import Path
4
- import yaml
5
- import re
6
- import logging
7
- import io
8
- import sys
9
- import re
10
- from datetime import datetime, timezone, timedelta
11
- import requests
12
-
13
- CONFIG = None
14
- HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"
15
-
16
- def load_hf_config():
17
- global CONFIG
18
- if CONFIG is None:
19
- try:
20
- with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
21
- CONFIG = yaml.safe_load(f)
22
- print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
23
- except FileNotFoundError:
24
- print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
25
- CONFIG = {}
26
- except Exception as e:
27
- print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
28
- CONFIG = {}
29
- return CONFIG
30
-
31
- def get_hf_openai_config():
32
- config = load_hf_config()
33
- return config.get('openai', {})
34
-
35
- def get_hf_openai_filter_config():
36
- config = load_hf_config()
37
- return config.get('openai_filter', {})
38
-
39
- def get_hf_paths_config():
40
- config = load_hf_config()
41
- base = Path(__file__).resolve().parent
42
- paths_cfg = config.get('paths', {})
43
- return {
44
- 'base_dir': base,
45
- 'prompt_template': base / paths_cfg.get('prompt_template', 'prompt_template.txt'),
46
- 'true_positive_examples': base / paths_cfg.get('true_positive_examples', 'TruePositive_few_shot.txt'),
47
- 'false_positive_examples': base / paths_cfg.get('false_positive_examples', 'FalsePositive_few_shot.txt'),
48
- }
49
-
50
- llm_config = get_hf_openai_config()
51
- NVIDIA_API_BASE_URL = llm_config.get('base_url')
52
- NVIDIA_API_KEY = llm_config.get('api_key')
53
- NVIDIA_MODEL_NAME = llm_config.get('model')
54
-
55
- filter_config = get_hf_openai_filter_config()
56
- Filter_API_BASE_URL = filter_config.get('base_url_filter')
57
- Filter_API_KEY = filter_config.get('api_key_filter')
58
- Filter_MODEL_NAME = filter_config.get('model_filter')
59
-
60
- if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
61
- print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
62
- NVIDIA_API_BASE_URL = ""
63
- NVIDIA_API_KEY = ""
64
- NVIDIA_MODEL_NAME = ""
65
-
66
- if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
67
- print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
68
- Filter_API_BASE_URL = ""
69
- Filter_API_KEY = ""
70
- Filter_MODEL_NAME = ""
71
-
72
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
73
- logger = logging.getLogger(__name__)
74
-
75
- def load_single_few_shot_file_hf(file_path: Path) -> str:
76
- try:
77
- with open(file_path, 'r', encoding='utf-8') as f:
78
- content = f.read()
79
- escaped_content = content.replace('{', '{{').replace('}', '}}')
80
- return escaped_content
81
- except FileNotFoundError:
82
- return ""
83
- except Exception:
84
- return ""
85
-
86
- PROMPT_TEMPLATE_CONTENT = ""
87
- TRUE_POSITIVE_EXAMPLES_CONTENT = ""
88
- FALSE_POSITIVE_EXAMPLES_CONTENT = ""
89
-
90
- def load_prompt_data_hf():
91
- global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
92
- paths = get_hf_paths_config()
93
- try:
94
- with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
95
- PROMPT_TEMPLATE_CONTENT = f.read()
96
- except FileNotFoundError:
97
- PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."
98
-
99
- TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
100
- FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])
101
-
102
- load_prompt_data_hf()
103
-
104
- def _process_parsed_json(parsed_data):
105
- try:
106
- if isinstance(parsed_data, list):
107
- if not parsed_data:
108
- return [{}]
109
-
110
- processed_list = []
111
- for item in parsed_data:
112
- if isinstance(item, dict):
113
- processed_list.append(item)
114
- else:
115
- try:
116
- processed_list.append({"content": str(item)})
117
- except:
118
- processed_list.append({"content": "无法转换的项目"})
119
-
120
- if not processed_list:
121
- return [{}]
122
-
123
- return processed_list
124
-
125
- elif isinstance(parsed_data, dict):
126
- return parsed_data
127
-
128
- else:
129
- return {"content": str(parsed_data)}
130
-
131
- except Exception as e:
132
- return {"error": f"Error processing parsed JSON: {e}"}
133
-
134
- def json_parser(text: str) -> dict:
135
- try:
136
- try:
137
- parsed_data = json.loads(text)
138
- return _process_parsed_json(parsed_data)
139
- except json.JSONDecodeError:
140
- pass
141
-
142
- match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
143
- if match:
144
- json_str = match.group(1).strip()
145
- json_str = re.sub(r',\s*]', ']', json_str)
146
- json_str = re.sub(r',\s*}', '}', json_str)
147
- try:
148
- parsed_data = json.loads(json_str)
149
- return _process_parsed_json(parsed_data)
150
- except json.JSONDecodeError:
151
- pass
152
-
153
- array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
154
- if array_match:
155
- potential_json = array_match.group(0).strip()
156
- try:
157
- parsed_data = json.loads(potential_json)
158
- return _process_parsed_json(parsed_data)
159
- except json.JSONDecodeError:
160
- pass
161
-
162
- object_match = re.search(r'\{.*?\}', text, re.DOTALL)
163
- if object_match:
164
- potential_json = object_match.group(0).strip()
165
- try:
166
- parsed_data = json.loads(potential_json)
167
- return _process_parsed_json(parsed_data)
168
- except json.JSONDecodeError:
169
- pass
170
-
171
- return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
172
-
173
- except Exception as e:
174
- return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
175
-
176
- def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
177
- mock_data = [(text_input, message_id)]
178
-
179
- system_prompt = """
180
- # 角色
181
- 你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
182
-
183
- # 任务
184
- 对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
185
- 主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
186
-
187
- # 要求
188
- 1. 以json格式输出
189
- 2. content简洁提炼关键词,字符数<20以内
190
- 3. 输入条数和输出条数完全一样
191
-
192
- # 输出示例
193
- ```
194
- [
195
- {"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
196
- {"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
197
- ]
198
- ```
199
- """
200
-
201
- llm_messages = [
202
- {"role": "system", "content": system_prompt},
203
- {"role": "user", "content": str(mock_data)}
204
- ]
205
-
206
- try:
207
- if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
208
- return [{"error": "Filter API configuration incomplete", "-": "-"}]
209
-
210
- headers = {
211
- "Authorization": f"Bearer {Filter_API_KEY}",
212
- "Accept": "application/json"
213
- }
214
- payload = {
215
- "model": Filter_MODEL_NAME,
216
- "messages": llm_messages,
217
- "temperature": 0.0,
218
- "top_p": 0.95,
219
- "max_tokens": 1024,
220
- "stream": False
221
- }
222
-
223
- api_url = f"{Filter_API_BASE_URL}/chat/completions"
224
-
225
- try:
226
- response = requests.post(api_url, headers=headers, json=payload)
227
- response.raise_for_status()
228
- raw_llm_response = response.json()["choices"][0]["message"]["content"]
229
- except requests.exceptions.RequestException as e:
230
- return [{"error": f"Filter API call failed: {e}", "-": "-"}]
231
-
232
- raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
233
- parsed_filter_data = json_parser(raw_llm_response)
234
-
235
- if "error" in parsed_filter_data:
236
- return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]
237
-
238
- if isinstance(parsed_filter_data, list) and parsed_filter_data:
239
- for item in parsed_filter_data:
240
- if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
241
- item["分类"] = "其他"
242
-
243
- request_id_list = {message_id}
244
- response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
245
- diff = request_id_list - response_id_list
246
-
247
- if diff:
248
- for missed_id in diff:
249
- parsed_filter_data.append({
250
- "message_id": missed_id,
251
- "content": text_input[:20],
252
- "物流取件": 0,
253
- "欠费缴纳": 0,
254
- "待付(还)款": 0,
255
- "会议邀约": 0,
256
- "其他": 100,
257
- "分类": "其他"
258
- })
259
-
260
- return parsed_filter_data
261
- else:
262
- return [{
263
- "message_id": message_id,
264
- "content": text_input[:20],
265
- "物流取件": 0,
266
- "欠费缴纳": 0,
267
- "待付(还)款": 0,
268
- "会议邀约": 0,
269
- "其他": 100,
270
- "分类": "其他",
271
- "error": "Filter LLM returned empty or unexpected format"
272
- }]
273
-
274
- except Exception as e:
275
- return [{
276
- "message_id": message_id,
277
- "content": text_input[:20],
278
- "物流取件": 0,
279
- "欠费缴纳": 0,
280
- "待付(还)款": 0,
281
- "会议邀约": 0,
282
- "其他": 100,
283
- "分类": "其他",
284
- "error": f"Filter LLM call/parse error: {str(e)}"
285
- }]
286
-
287
- def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
288
- if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
289
- return [["error", "Prompt template not loaded", "-"]]
290
-
291
- current_time_iso = datetime.now(timezone.utc).isoformat()
292
- content_escaped = text_input.replace('{', '{{').replace('}', '}}')
293
-
294
- formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
295
- true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
296
- false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
297
- current_time=current_time_iso,
298
- message_id=message_id,
299
- content_escaped=content_escaped
300
- )
301
-
302
- enhanced_prompt = formatted_prompt + """
303
-
304
- # 重要提示
305
- 请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
306
- 你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
307
- """
308
-
309
- llm_messages = [
310
- {"role": "user", "content": enhanced_prompt}
311
- ]
312
-
313
- try:
314
- if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
315
- todo_item = {
316
- message_id: {
317
- "is_todo": True,
318
- "end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
319
- "location": "线上:中国移动APP",
320
- "todo_content": "缴纳话费",
321
- "urgency": "important"
322
- }
323
- }
324
-
325
- todo_content = "缴纳话费"
326
- end_time = todo_item[message_id]["end_time"].split("T")[0]
327
- location = todo_item[message_id]["location"]
328
-
329
- combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
330
-
331
- output_for_df = []
332
- output_for_df.append([1, combined_content, "重要"])
333
-
334
- return output_for_df
335
-
336
- elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
337
- meeting_time = None
338
- meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
339
- meeting_match = re.search(meeting_pattern, text_input)
340
-
341
- if meeting_match:
342
- meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
343
- else:
344
- meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
345
-
346
- todo_item = {
347
- message_id: {
348
- "is_todo": True,
349
- "end_time": meeting_time,
350
- "location": "线上:会议软件",
351
- "todo_content": "参加会议",
352
- "urgency": "important"
353
- }
354
- }
355
-
356
- todo_content = "参加会议"
357
- end_time = todo_item[message_id]["end_time"].split("T")[0]
358
- location = todo_item[message_id]["location"]
359
-
360
- combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
361
-
362
- output_for_df = []
363
- output_for_df.append([1, combined_content, "重要"])
364
-
365
- return output_for_df
366
-
367
- elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
368
- pickup_code = None
369
- code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
370
- code_match = re.search(code_pattern, text_input)
371
-
372
- todo_content = "取快递"
373
- if code_match:
374
- pickup_code = code_match.group(1)
375
- todo_content = f"取快递(取件码:{pickup_code})"
376
-
377
- todo_item = {
378
- message_id: {
379
- "is_todo": True,
380
- "end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
381
- "location": "线下:快递柜",
382
- "todo_content": todo_content,
383
- "urgency": "important"
384
- }
385
- }
386
-
387
- end_time = todo_item[message_id]["end_time"].split("T")[0]
388
- location = todo_item[message_id]["location"]
389
-
390
- combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
391
-
392
- output_for_df = []
393
- output_for_df.append([1, combined_content, "重要"])
394
-
395
- return output_for_df
396
-
397
- if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
398
- return [["error", "Filter API configuration incomplete", "-"]]
399
-
400
- headers = {
401
- "Authorization": f"Bearer {Filter_API_KEY}",
402
- "Accept": "application/json"
403
- }
404
- payload = {
405
- "model": Filter_MODEL_NAME,
406
- "messages": llm_messages,
407
- "temperature": 0.2,
408
- "top_p": 0.95,
409
- "max_tokens": 1024,
410
- "stream": False
411
- }
412
-
413
- api_url = f"{Filter_API_BASE_URL}/chat/completions"
414
-
415
- try:
416
- response = requests.post(api_url, headers=headers, json=payload)
417
- response.raise_for_status()
418
- raw_llm_response = response.json()['choices'][0]['message']['content']
419
- except requests.exceptions.RequestException as e:
420
- return [["error", f"Filter API call failed: {e}", "-"]]
421
-
422
- parsed_todos_data = json_parser(raw_llm_response)
423
-
424
- if "error" in parsed_todos_data:
425
- return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]
426
-
427
- output_for_df = []
428
-
429
- if isinstance(parsed_todos_data, dict):
430
- todo_info = None
431
- for key, value in parsed_todos_data.items():
432
- if key == message_id or key == str(message_id):
433
- todo_info = value
434
- break
435
-
436
- if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
437
- todo_content = todo_info.get("todo_content", "未指定待办内容")
438
- end_time = todo_info.get("end_time")
439
- location = todo_info.get("location")
440
- urgency = todo_info.get("urgency", "unimportant")
441
-
442
- combined_content = todo_content
443
-
444
- if end_time and end_time != "null":
445
- try:
446
- date_part = end_time.split("T")[0] if "T" in end_time else end_time
447
- combined_content += f" (截止时间: {date_part}"
448
- except:
449
- combined_content += f" (截止时间: {end_time}"
450
- else:
451
- combined_content += " ("
452
-
453
- if location and location != "null":
454
- combined_content += f", 地点: {location})"
455
- else:
456
- combined_content += ")"
457
-
458
- urgency_display = "一般"
459
- if urgency == "urgent":
460
- urgency_display = "紧急"
461
- elif urgency == "important":
462
- urgency_display = "重要"
463
-
464
- output_for_df = []
465
- output_for_df.append([1, combined_content, urgency_display])
466
- else:
467
- output_for_df = []
468
- output_for_df.append([1, "此消息不包含待办事项", "-"])
469
-
470
- elif isinstance(parsed_todos_data, list):
471
- output_for_df = []
472
-
473
- if not parsed_todos_data:
474
- return [[1, "未能生成待办事项", "-"]]
475
-
476
- for i, item in enumerate(parsed_todos_data):
477
- if isinstance(item, dict):
478
- todo_content = item.get('todo_content', item.get('content', 'N/A'))
479
- status = item.get('status', '未完成')
480
- urgency = item.get('urgency', 'normal')
481
-
482
- combined_content = todo_content
483
-
484
- if 'end_time' in item and item['end_time']:
485
- try:
486
- if isinstance(item['end_time'], str):
487
- date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
488
- combined_content += f" (截止时间: {date_part}"
489
- else:
490
- combined_content += f" (截止时间: {str(item['end_time'])}"
491
- except Exception:
492
- combined_content += " ("
493
- else:
494
- combined_content += " ("
495
-
496
- if 'location' in item and item['location']:
497
- combined_content += f", 地点: {item['location']})"
498
- else:
499
- combined_content += ")"
500
-
501
- importance = "一般"
502
- if urgency == "urgent":
503
- importance = "紧急"
504
- elif urgency == "important":
505
- importance = "重要"
506
-
507
- output_for_df.append([i + 1, combined_content, importance])
508
- else:
509
- try:
510
- item_str = str(item) if item is not None else "未知项目"
511
- output_for_df.append([i + 1, item_str, "一般"])
512
- except Exception:
513
- output_for_df.append([i + 1, "处理错误的项目", "一般"])
514
-
515
- if not output_for_df:
516
- return [["info", "未发现待办事项", "-"]]
517
-
518
- return output_for_df
519
-
520
- except Exception as e:
521
- return [["error", f"LLM call/parse error: {str(e)}", "-"]]
522
-
523
- def process(audio, image):
524
- if audio is not None:
525
- sample_rate, audio_data = audio
526
- audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
527
- else:
528
- audio_info = "未收到音频"
529
-
530
- if image is not None:
531
- image_info = f"图片尺寸: {image.shape}"
532
- else:
533
- image_info = "未收到图片"
534
-
535
- return audio_info, image_info
536
-
537
- def respond(message, history, system_message, max_tokens, temperature, top_p, audio, image):
538
- chat_messages = [{"role": "system", "content": system_message}]
539
- for val in history:
540
- if val[0]:
541
- chat_messages.append({"role": "user", "content": val[0]})
542
- if val[1]:
543
- chat_messages.append({"role": "assistant", "content": val[1]})
544
- chat_messages.append({"role": "user", "content": message})
545
-
546
- chat_response_stream = ""
547
- if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
548
- yield "Filter API 配置不完整,无法提供聊天回复。", []
549
- return
550
-
551
- headers = {
552
- "Authorization": f"Bearer {Filter_API_KEY}",
553
- "Accept": "application/json"
554
- }
555
- payload = {
556
- "model": Filter_MODEL_NAME,
557
- "messages": chat_messages,
558
- "temperature": temperature,
559
- "top_p": top_p,
560
- "max_tokens": max_tokens,
561
- "stream": True
562
- }
563
- api_url = f"{Filter_API_BASE_URL}/chat/completions"
564
-
565
- try:
566
- response = requests.post(api_url, headers=headers, json=payload, stream=True)
567
- response.raise_for_status()
568
-
569
- for chunk in response.iter_content(chunk_size=None):
570
- if chunk:
571
- try:
572
- for line in chunk.decode('utf-8').splitlines():
573
- if line.startswith('data: '):
574
- json_data = line[len('data: '):]
575
- if json_data.strip() == '[DONE]':
576
- break
577
- data = json.loads(json_data)
578
- token = data['choices'][0]['delta'].get('content', '')
579
- if token:
580
- chat_response_stream += token
581
- yield chat_response_stream, []
582
- except json.JSONDecodeError:
583
- pass
584
- except Exception as e:
585
- yield chat_response_stream + f"\n\n错误: {e}", []
586
-
587
- except requests.exceptions.RequestException as e:
588
- yield f"调用 NVIDIA API 失败: {e}", []
589
-
590
- with gr.Blocks() as app:
591
- gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")
592
-
593
- with gr.Row():
594
- with gr.Column(scale=2):
595
- gr.Markdown("## Chat Interface")
596
- chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages")
597
- msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")
598
-
599
- with gr.Row():
600
- audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
601
- image_input = gr.Image(label="上传图片", type="numpy")
602
-
603
- with gr.Accordion("高级设置", open=False):
604
- system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
605
- max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)")
606
- temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
607
- top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")
608
-
609
- with gr.Row():
610
- submit_btn = gr.Button("发送", variant="primary")
611
- clear_btn = gr.Button("清除聊天和ToDo")
612
-
613
- with gr.Column(scale=1):
614
- gr.Markdown("## Generated ToDo List")
615
- todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
616
- datatype=["number", "str", "str"],
617
- row_count=(0, "dynamic"),
618
- col_count=(3, "fixed"),
619
- label="待办事项列表")
620
-
621
- def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f):
622
- if not ch_history: ch_history = []
623
- ch_history.append({"role": "user", "content": user_msg_content})
624
- yield ch_history, []
625
-
626
- formatted_hist_for_respond = []
627
- temp_user_msg_for_hist = None
628
- for item_hist in ch_history[:-1]:
629
- if item_hist["role"] == "user":
630
- temp_user_msg_for_hist = item_hist["content"]
631
- elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
632
- formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
633
- temp_user_msg_for_hist = None
634
- elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
635
- formatted_hist_for_respond.append(("", item_hist["content"]))
636
-
637
- ch_history.append({"role": "assistant", "content": ""})
638
-
639
- full_bot_response = ""
640
- for bot_response_token, _ in respond(user_msg_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
641
- full_bot_response = bot_response_token
642
- ch_history[-1]["content"] = full_bot_response
643
- yield ch_history, []
644
-
645
- text_for_todo = user_msg_content
646
- current_todos_list = []
647
-
648
- filtered_result = filter_message_with_llm(text_for_todo)
649
-
650
- if isinstance(filtered_result, dict) and "error" in filtered_result:
651
- current_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
652
- elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
653
- current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
654
- elif isinstance(filtered_result, list):
655
- category = None
656
-
657
- if not filtered_result:
658
- if text_for_todo:
659
- msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
660
- current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
661
- yield ch_history, current_todos_list
662
- return
663
-
664
- valid_item = None
665
- for item in filtered_result:
666
- if isinstance(item, dict):
667
- valid_item = item
668
- if "分类" in item:
669
- category = item["分类"]
670
- break
671
-
672
- if valid_item is None:
673
- if text_for_todo:
674
- msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
675
- current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
676
- yield ch_history, current_todos_list
677
- return
678
-
679
- if category == "其他":
680
- current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
681
- else:
682
- if text_for_todo:
683
- msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
684
- current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
685
- else:
686
- if text_for_todo:
687
- msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
688
- current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
689
-
690
- yield ch_history, current_todos_list
691
-
692
- submit_btn.click(
693
- handle_submit,
694
- [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
695
- [chatbot, todolist_df]
696
- )
697
- msg.submit(
698
- handle_submit,
699
- [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
700
- [chatbot, todolist_df]
701
- )
702
-
703
- def clear_all():
704
- return None, None, ""
705
- clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)
706
-
707
- with gr.Tab("Audio/Image Processing (Original)"):
708
- gr.Markdown("## 处理音频和图片")
709
- audio_processor = gr.Audio(label="上传音频", type="numpy")
710
- image_processor = gr.Image(label="上传图片", type="numpy")
711
- process_btn = gr.Button("处理", variant="primary")
712
- audio_output = gr.Textbox(label="音频信息")
713
- image_output = gr.Textbox(label="图片信息")
714
-
715
- process_btn.click(
716
- process,
717
- inputs=[audio_processor, image_processor],
718
- outputs=[audio_output, image_output]
719
- )
720
-
721
- if __name__ == "__main__":
722
- app.launch(debug=False)
1
+ import gradio as gr
2
+ import json
3
+ from pathlib import Path
4
+ import yaml
5
+ import re
6
+ import logging
7
+ import io
8
+ import sys
9
+ import os
10
+ import re
11
+ from datetime import datetime, timezone, timedelta
12
+ import requests
13
+
14
+ from tools import FileUploader, ResultExtractor, audio_to_str, image_to_str, azure_speech_to_text #gege的多模态
15
+ import numpy as np
16
+ from scipy.io.wavfile import write as write_wav
17
+ from PIL import Image
18
+
19
+ # 指定保存文件的相对路径
20
+ SAVE_DIR = 'download' # 相对路径
21
+ os.makedirs(SAVE_DIR, exist_ok=True) # 确保目录存在
22
+
23
+ def save_audio(audio, filename):
24
+ """保存音频为.wav文件"""
25
+ sample_rate, audio_data = audio
26
+ write_wav(filename, sample_rate, audio_data)
27
+
28
+ def save_image(image, filename):
29
+ """保存图片为.jpg文件"""
30
+ img = Image.fromarray(image.astype('uint8'))
31
+ img.save(filename)
32
+
33
+ # --- IP获取功能 ( se_app.py 迁移) ---
34
+ def get_client_ip(request: gr.Request, debug_mode=False):
35
+ """获取客户端真实IP地址"""
36
+ if request:
37
+ # 从请求头中获取真实IP(考虑代理情况)
38
+ x_forwarded_for = request.headers.get("x-forwarded-for", "")
39
+ if x_forwarded_for:
40
+ client_ip = x_forwarded_for.split(",")[0]
41
+ else:
42
+ client_ip = request.client.host
43
+ if debug_mode:
44
+ print(f"Debug: Client IP detected as {client_ip}")
45
+ return client_ip
46
+ return "unknown"
47
+
48
+ # --- 配置加载 (从 config_loader.py 迁移并简化) ---
49
+ CONFIG = None
50
+ HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"
51
+
52
+ def load_hf_config():
53
+ global CONFIG
54
+ if CONFIG is None:
55
+ try:
56
+ with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
57
+ CONFIG = yaml.safe_load(f)
58
+ print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
59
+ except FileNotFoundError:
60
+ print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
61
+ CONFIG = {} # 提供一个空配置以避免后续错误
62
+ except Exception as e:
63
+ print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
64
+ CONFIG = {}
65
+ return CONFIG
66
+
67
+ def get_hf_openai_config():
68
+ config = load_hf_config()
69
+ return config.get('openai', {})
70
+
71
+ def get_hf_openai_filter_config():
72
+ config = load_hf_config()
73
+ return config.get('openai_filter', {})
74
+
75
+ def get_hf_xunfei_config():
76
+ config = load_hf_config()
77
+ return config.get('xunfei', {})
78
+
79
+ def get_hf_azure_speech_config():
80
+ config = load_hf_config()
81
+ return config.get('azure_speech', {})
82
+
83
+ def get_hf_paths_config():
84
+ config = load_hf_config()
85
+ # 在hf环境下,路径相对于hf目录
86
+ base = Path(__file__).resolve().parent
87
+ paths_cfg = config.get('paths', {})
88
+ return {
89
+ 'base_dir': base,
90
+ 'prompt_template': base / paths_cfg.get('prompt_template', 'prompt_template.txt'),
91
+ 'true_positive_examples': base / paths_cfg.get('true_positive_examples', 'TruePositive_few_shot.txt'),
92
+ 'false_positive_examples': base / paths_cfg.get('false_positive_examples', 'FalsePositive_few_shot.txt'),
93
+ # data_dir 和 logging_dir 在 app.py 中可能用途不大,除非需要保存 LLM 输出
94
+ }
95
+
96
+ # --- LLM Client 初始化 (使用 NVIDIA API) ---
97
+ # 从配置加载 NVIDIA API base_url, api_key 和 model
98
+ llm_config = get_hf_openai_config()
99
+ NVIDIA_API_BASE_URL = llm_config.get('base_url')
100
+ NVIDIA_API_KEY = llm_config.get('api_key')
101
+ NVIDIA_MODEL_NAME = llm_config.get('model')
102
+
103
+ # 从配置加载 Filter API 的 base_url, api_key 和 model
104
+ filter_config = get_hf_openai_filter_config()
105
+ Filter_API_BASE_URL = filter_config.get('base_url_filter')
106
+ Filter_API_KEY = filter_config.get('api_key_filter')
107
+ Filter_MODEL_NAME = filter_config.get('model_filter')
108
+
109
+
110
+ if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
111
+ print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
112
+ # 提供默认值或退出,以便程序可以继续运行,但LLM调用会失败
113
+ NVIDIA_API_BASE_URL = ""
114
+ NVIDIA_API_KEY = ""
115
+ NVIDIA_MODEL_NAME = ""
116
+
117
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
118
+ print(" 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
119
+ # 提供默认值或退出,以便程序可以继续运行,但Filter LLM调用会失败
120
+ Filter_API_BASE_URL = ""
121
+ Filter_API_KEY = ""
122
+ Filter_MODEL_NAME = ""
123
+
124
+ # --- 日志配置 (简化版) ---
125
+ # 修正后的标准流编码设置 (如果需要,但 Gradio 通常处理自己的输出)
126
+ # sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
127
+ # sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True)
128
+ # sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', write_through=True)
129
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
130
+ logger = logging.getLogger(__name__)
131
+
132
+ # --- Prompt Few-Shot 加载 (从 todogen_llm.py 迁移并适配) ---
133
+ def load_single_few_shot_file_hf(file_path: Path) -> str:
134
+ """加载单个 few-shot 文件并转义 { 和 }"""
135
+ try:
136
+ with open(file_path, 'r', encoding='utf-8') as f:
137
+ content = f.read()
138
+ escaped_content = content.replace('{', '{{').replace('}', '}}')
139
+ logger.info(f"✅ 成功加载并转义文件: {file_path}")
140
+ return escaped_content
141
+ except FileNotFoundError:
142
+ logger.warning(f"⚠️ 警告:找不到文件 {file_path}。")
143
+ return ""
144
+ except Exception as e:
145
+ logger.error(f"❌ 加载文件 {file_path} 时出错: {e}", exc_info=True)
146
+ return ""
147
+
148
+ PROMPT_TEMPLATE_CONTENT = ""
149
+ TRUE_POSITIVE_EXAMPLES_CONTENT = ""
150
+ FALSE_POSITIVE_EXAMPLES_CONTENT = ""
151
+
152
+ def load_prompt_data_hf():
153
+ global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
154
+ paths = get_hf_paths_config()
155
+ try:
156
+ with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
157
+ PROMPT_TEMPLATE_CONTENT = f.read()
158
+ logger.info(f"✅ 成功加载 Prompt 模板文件: {paths['prompt_template']}")
159
+ except FileNotFoundError:
160
+ logger.error(f"❌ 错误:找不到 Prompt 模板文件:{paths['prompt_template']}")
161
+ PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."
162
+
163
+ TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
164
+ FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])
165
+
166
+ # 应用启动时加载 prompts
167
+ load_prompt_data_hf()
168
+
169
+ # --- JSON 解析器 (从 todogen_llm.py 迁移) ---
170
+ def json_parser(text: str) -> dict:
171
+ # 改进的JSON解析器,更健壮地处理各种格式
172
+ logger.info(f"Attempting to parse: {text[:200]}...")
173
+ try:
174
+ # 1. 尝试直接将整个文本作为JSON解析
175
+ try:
176
+ parsed_data = json.loads(text)
177
+ # 使用_process_parsed_json处理解析结果
178
+ return _process_parsed_json(parsed_data)
179
+ except json.JSONDecodeError:
180
+ pass # 如果直接解析失败,继续尝试提取代码块
181
+
182
+ # 2. 尝试从 ```json ... ``` 代码块中提取和解析
183
+ match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
184
+ if match:
185
+ json_str = match.group(1).strip()
186
+ # 修复常见的JSON格式问题
187
+ json_str = re.sub(r',\s*]', ']', json_str)
188
+ json_str = re.sub(r',\s*}', '}', json_str)
189
+ try:
190
+ parsed_data = json.loads(json_str)
191
+ # 使用_process_parsed_json处理解析结果
192
+ return _process_parsed_json(parsed_data)
193
+ except json.JSONDecodeError as e_block:
194
+ logger.warning(f"JSONDecodeError from code block: {e_block} while parsing: {json_str[:200]}")
195
+ # 如果从代码块解析也失败,则继续
196
+
197
+ # 3. 尝试查找最外层的 '{...}' 或 '[...]' 作为JSON
198
+ # 先尝试查找数组格式 [...]
199
+ array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
200
+ if array_match:
201
+ potential_json = array_match.group(0).strip()
202
+ try:
203
+ parsed_data = json.loads(potential_json)
204
+ # 使用_process_parsed_json处理解析结果
205
+ return _process_parsed_json(parsed_data)
206
+ except json.JSONDecodeError:
207
+ logger.warning(f"Could not parse potential JSON array: {potential_json[:200]}")
208
+ pass
209
+
210
+ # 再尝试查找单个对象格式 {...}
211
+ object_match = re.search(r'\{.*?\}', text, re.DOTALL)
212
+ if object_match:
213
+ potential_json = object_match.group(0).strip()
214
+ try:
215
+ parsed_data = json.loads(potential_json)
216
+ # 使用_process_parsed_json处理解析结果
217
+ return _process_parsed_json(parsed_data)
218
+ except json.JSONDecodeError:
219
+ logger.warning(f"Could not parse potential JSON object: {potential_json[:200]}")
220
+ pass
221
+
222
+ # 4. 如果所有尝试都失败,返回错误信息
223
+ logger.error(f"Failed to find or parse JSON block in text: {text[:500]}") # 增加日志长度
224
+ return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
225
+
226
+ except Exception as e: # 捕获所有其他意外错误
227
+ logger.error(f"Unexpected error in json_parser: {e} for text: {text[:200]}", exc_info=True)
228
+ return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
229
+
230
+ def _process_parsed_json(parsed_data):
231
+ """处理解析后的JSON数据,确保返回有效的数据结构"""
232
+ try:
233
+ # 如果解析结果是空列表,返回包含空字典的列表
234
+ if isinstance(parsed_data, list):
235
+ if not parsed_data:
236
+ logger.warning("JSON解析结果为空列表,返回包含空字典的列表")
237
+ return [{}]
238
+
239
+ # 确保列表中的每个元素都是字典
240
+ processed_list = []
241
+ for item in parsed_data:
242
+ if isinstance(item, dict):
243
+ processed_list.append(item)
244
+ else:
245
+ # 如果不是字典,将其转换为字典
246
+ try:
247
+ processed_list.append({"content": str(item)})
248
+ except:
249
+ processed_list.append({"content": "无法转换的项目"})
250
+
251
+ # 如果处理后的列表为空,返回包含空字典的列表
252
+ if not processed_list:
253
+ logger.warning("处理后的JSON列表为空,返回包含空字典的列表")
254
+ return [{}]
255
+
256
+ return processed_list
257
+
258
+ # 如果是字典,直接返回
259
+ elif isinstance(parsed_data, dict):
260
+ return parsed_data
261
+
262
+ # 如果是其他类型,转换为字典
263
+ else:
264
+ logger.warning(f"JSON解析结果不是列表或字典,而是{type(parsed_data)},转换为字典")
265
+ return {"content": str(parsed_data)}
266
+
267
+ except Exception as e:
268
+ logger.error(f"处理解析后的JSON数据时出错: {e}")
269
+ return {"error": f"Error processing parsed JSON: {e}"}
270
+
271
+ # --- Filter 模块的 System Prompt (从 filter_message/libs.py 迁移) ---
272
+ FILTER_SYSTEM_PROMPT = """
273
+ # 角色
274
+ 你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
275
+
276
+ # 任务
277
+ 对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
278
+ 主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
279
+
280
+ # 要求
281
+ 1. 以json格式输出
282
+ 2. content简洁提炼关键词,字符数<20以内
283
+ 3. 输入条数和输出条数完全一样
284
+
285
+ # 输出示例
286
+ ```
287
+ [
288
+ {"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
289
+ {"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议"}
290
+ ]
291
+
292
+ ```
293
+ """
294
+
295
+ # --- Filter 核心逻辑 (从ToDoAgent集成) ---
296
+ def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
297
+ logger.info(f"调用 filter_message_with_llm 处理输入: {text_input} (msg_id: {message_id})")
298
+
299
+ # 构造发送给 LLM 的消息
300
+ # filter 模块的 send_llm_with_prompt 接收的是 tuple[tuple] 格式的数据
301
+ # 这里我们只有一个文本输入,需要模拟成那种格式
302
+ mock_data = [(text_input, message_id)]
303
+
304
+ # 使用与ToDoAgent相同的system prompt
305
+ system_prompt = """
306
+ # 角色
307
+ 你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
308
+
309
+ # 任务
310
+ 对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
311
+ 主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
312
+
313
+ # 要求
314
+ 1. 以json格式输出
315
+ 2. content简洁提炼关键词,字符数<20以内
316
+ 3. 输入条数和输出条数完全一样
317
+
318
+ # 输出示例
319
+ ```
320
+ [
321
+ {"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
322
+ {"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
323
+ ]
324
+ ```
325
+ """
326
+
327
+ llm_messages = [
328
+ {"role": "system", "content": system_prompt},
329
+ {"role": "user", "content": str(mock_data)}
330
+ ]
331
+
332
+ try:
333
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
334
+ logger.error("Filter API 配置不完整,无法调用 Filter LLM。")
335
+ return [{"error": "Filter API configuration incomplete", "-": "-"}]
336
+
337
+ headers = {
338
+ "Authorization": f"Bearer {Filter_API_KEY}",
339
+ "Accept": "application/json"
340
+ }
341
+ payload = {
342
+ "model": Filter_MODEL_NAME,
343
+ "messages": llm_messages,
344
+ "temperature": 0.0, # 为提高准确率,温度为0(与ToDoAgent一致)
345
+ "top_p": 0.95,
346
+ "max_tokens": 1024,
347
+ "stream": False
348
+ }
349
+
350
+ api_url = f"{Filter_API_BASE_URL}/chat/completions"
351
+
352
+ try:
353
+ response = requests.post(api_url, headers=headers, json=payload)
354
+ response.raise_for_status() # 检查 HTTP 错误
355
+ raw_llm_response = response.json()["choices"][0]["message"]["content"]
356
+ logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...")
357
+ except requests.exceptions.RequestException as e:
358
+ logger.error(f"调用 Filter API 失败: {e}")
359
+ return [{"error": f"Filter API call failed: {e}", "-": "-"}]
360
+ logger.info(f"Filter LLM 原始回复 (部分): {raw_llm_response[:200]}...")
361
+
362
+ # 解析 LLM 响应
363
+ # 移除可能的代码块标记
364
+ raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
365
+ parsed_filter_data = json_parser(raw_llm_response)
366
+
367
+ if "error" in parsed_filter_data:
368
+ logger.error(f"解析 Filter LLM 响应失败: {parsed_filter_data['error']}")
369
+ return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]
370
+
371
+ # 返回解析后的数据
372
+ if isinstance(parsed_filter_data, list) and parsed_filter_data:
373
+ # 应用规则:如果分类是欠费缴纳且内容包含"缴费支出",归类为"其他"
374
+ for item in parsed_filter_data:
375
+ if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
376
+ item["分类"] = "其他"
377
+
378
+ # 检查是否有遗漏的消息ID(ToDoAgent的补充逻辑)
379
+ request_id_list = {message_id}
380
+ response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
381
+ diff = request_id_list - response_id_list
382
+
383
+ if diff:
384
+ logger.warning(f"Filter LLM 响应中有遗漏的消息ID: {diff}")
385
+ # 对于遗漏的消息,添加一个默认分类为"其他"的项
386
+ for missed_id in diff:
387
+ parsed_filter_data.append({
388
+ "message_id": missed_id,
389
+ "content": text_input[:20], # 截取前20个字符作为content
390
+ "物流取件": 0,
391
+ "欠费缴纳": 0,
392
+ "待付(还)款": 0,
393
+ "会议邀约": 0,
394
+ "其他": 100,
395
+ "分类": "其他"
396
+ })
397
+
398
+ return parsed_filter_data
399
+ else:
400
+ logger.warning(f"Filter LLM 返回空列表或非预期格式: {parsed_filter_data}")
401
+ # 返回默认分类为"其他"的项
402
+ return [{
403
+ "message_id": message_id,
404
+ "content": text_input[:20], # 截取前20个字符作为content
405
+ "物流取件": 0,
406
+ "欠费缴纳": 0,
407
+ "待付(还)款": 0,
408
+ "会议邀约": 0,
409
+ "其他": 100,
410
+ "分类": "其他",
411
+ "error": "Filter LLM returned empty or unexpected format"
412
+ }]
413
+
414
+ except Exception as e:
415
+ logger.exception(f"调用 Filter LLM 或解析时发生错误 (filter_message_with_llm)")
416
+ return [{
417
+ "message_id": message_id,
418
+ "content": text_input[:20], # 截取前20个字符作为content
419
+ "物流取件": 0,
420
+ "欠费缴纳": 0,
421
+ "待付(还)款": 0,
422
+ "会议邀约": 0,
423
+ "其他": 100,
424
+ "分类": "其他",
425
+ "error": f"Filter LLM call/parse error: {str(e)}"
426
+ }]
427
+
428
+ # --- ToDo List 生成核心逻辑 (使用迁移的代码) ---
429
+ def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
430
+ """根据输入文本生成 ToDoList (使用迁移的逻辑)"""
431
+ logger.info(f"调用 generate_todolist_from_text 处理输入: {text_input} (msg_id: {message_id})")
432
+
433
+ if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
434
+ logger.error("Prompt 模板未正确加载,无法生成 ToDoList。")
435
+ return [["error", "Prompt template not loaded", "-"]]
436
+
437
+ current_time_iso = datetime.now(timezone.utc).isoformat()
438
+ # 转义输入内容中的 { 和 }
439
+ content_escaped = text_input.replace('{', '{{').replace('}', '}}')
440
+
441
+ # 构造 prompt
442
+ formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
443
+ true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
444
+ false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
445
+ current_time=current_time_iso,
446
+ message_id=message_id,
447
+ content_escaped=content_escaped
448
+ )
449
+
450
+ # 添加明确的JSON输出指令
451
+ enhanced_prompt = formatted_prompt + """
452
+
453
+ # 重要提示
454
+ 请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
455
+ 你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
456
+ """
457
+
458
+ # 构造发送给 LLM 的消息
459
+ llm_messages = [
460
+ {"role": "user", "content": enhanced_prompt}
461
+ ]
462
+
463
+ logger.info(f"发送给 LLM 的消息 (部分): {str(llm_messages)[:300]}...")
464
+
465
+ try:
466
+ # 根据输入文本智能生成 ToDo List
467
+ # 如果是移动话费充值提醒类消息
468
+ if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
469
+ # 直接生成待办事项,不调用API
470
+ todo_item = {
471
+ message_id: {
472
+ "is_todo": True,
473
+ "end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
474
+ "location": "线上:中国移动APP",
475
+ "todo_content": "缴纳话费",
476
+ "urgency": "important"
477
+ }
478
+ }
479
+
480
+ # 转换为表格显示格式 - 合并为一行
481
+ todo_content = "缴纳话费"
482
+ end_time = todo_item[message_id]["end_time"].split("T")[0]
483
+ location = todo_item[message_id]["location"]
484
+
485
+ # 合并所有信息到任务内容中
486
+ combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
487
+
488
+ output_for_df = []
489
+ output_for_df.append([1, combined_content, "重要"])
490
+
491
+ return output_for_df
492
+
493
+ # 如果是会议邀约类消息
494
+ elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
495
+ # 提取可能的会议时间
496
+ meeting_time = None
497
+ meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
498
+ meeting_match = re.search(meeting_pattern, text_input)
499
+
500
+ if meeting_match:
501
+ # 简单处理,实际应用中应该更精确地解析日期时间
502
+ meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
503
+ else:
504
+ meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
505
+
506
+ todo_item = {
507
+ message_id: {
508
+ "is_todo": True,
509
+ "end_time": meeting_time,
510
+ "location": "线上:会议软件",
511
+ "todo_content": "参加会议",
512
+ "urgency": "important"
513
+ }
514
+ }
515
+
516
+ # 转换为表格显示格式 - 合并为一行
517
+ todo_content = "参加会议"
518
+ end_time = todo_item[message_id]["end_time"].split("T")[0]
519
+ location = todo_item[message_id]["location"]
520
+
521
+ # 合并所有信息到任务内容中
522
+ combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
523
+
524
+ output_for_df = []
525
+ output_for_df.append([1, combined_content, "重要"])
526
+
527
+ return output_for_df
528
+
529
+ # 如果是物流取件类消息
530
+ elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
531
+ # 提取可能的取件码
532
+ pickup_code = None
533
+ code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
534
+ code_match = re.search(code_pattern, text_input)
535
+
536
+ todo_content = "取快递"
537
+ if code_match:
538
+ pickup_code = code_match.group(1)
539
+ todo_content = f"取快递(取件码:{pickup_code})"
540
+
541
+ todo_item = {
542
+ message_id: {
543
+ "is_todo": True,
544
+ "end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
545
+ "location": "线下:快递柜",
546
+ "todo_content": todo_content,
547
+ "urgency": "important"
548
+ }
549
+ }
550
+
551
+ # 转换为表格显示格式 - 合并为一行
552
+ end_time = todo_item[message_id]["end_time"].split("T")[0]
553
+ location = todo_item[message_id]["location"]
554
+
555
+ # 合并所有信息到任务内容中
556
+ combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
557
+
558
+ output_for_df = []
559
+ output_for_df.append([1, combined_content, "重要"])
560
+
561
+ return output_for_df
562
+
563
+ # 对于其他类型的消息,调用LLM API进行处理
564
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
565
+ logger.error("Filter API 配置不完整,无法调用 Filter LLM。")
566
+ return [["error", "Filter API configuration incomplete", "-"]]
567
+
568
+ headers = {
569
+ "Authorization": f"Bearer {Filter_API_KEY}",
570
+ "Accept": "application/json"
571
+ }
572
+ payload = {
573
+ "model": Filter_MODEL_NAME,
574
+ "messages": llm_messages,
575
+ "temperature": 0.2, # 降低温度以提高一致性
576
+ "top_p": 0.95,
577
+ "max_tokens": 1024,
578
+ "stream": False
579
+ }
580
+
581
+ api_url = f"{Filter_API_BASE_URL}/chat/completions"
582
+
583
+ try:
584
+ response = requests.post(api_url, headers=headers, json=payload)
585
+ response.raise_for_status() # 检查 HTTP 错误
586
+ raw_llm_response = response.json()['choices'][0]['message']['content']
587
+ logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...")
588
+ except requests.exceptions.RequestException as e:
589
+ logger.error(f"调用 Filter API 失败: {e}")
590
+ return [["error", f"Filter API call failed: {e}", "-"]]
591
+
592
+ # 解析 LLM 响应
593
+ parsed_todos_data = json_parser(raw_llm_response)
594
+
595
+ if "error" in parsed_todos_data:
596
+ logger.error(f"解析 LLM 响应失败: {parsed_todos_data['error']}")
597
+ return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]
598
+
599
+ # 处理解析后的数据
600
+ output_for_df = []
601
+
602
+ # 如果是字典格式(符合prompt模板输出格式)
603
+ if isinstance(parsed_todos_data, dict):
604
+ # 获取消息ID对应的待办信息
605
+ todo_info = None
606
+ for key, value in parsed_todos_data.items():
607
+ if key == message_id or key == str(message_id):
608
+ todo_info = value
609
+ break
610
+
611
+ if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
612
+ # 提取待办信息
613
+ todo_content = todo_info.get("todo_content", "未指定待办内容")
614
+ end_time = todo_info.get("end_time")
615
+ location = todo_info.get("location")
616
+ urgency = todo_info.get("urgency", "unimportant")
617
+
618
+ # 准备合并显示的内容
619
+ combined_content = todo_content
620
+
621
+ # 添加截止时间
622
+ if end_time and end_time != "null":
623
+ try:
624
+ date_part = end_time.split("T")[0] if "T" in end_time else end_time
625
+ combined_content += f" (截止时间: {date_part}"
626
+ except:
627
+ combined_content += f" (截止时间: {end_time}"
628
+ else:
629
+ combined_content += " ("
630
+
631
+ # 添加地点
632
+ if location and location != "null":
633
+ combined_content += f", 地点: {location})"
634
+ else:
635
+ combined_content += ")"
636
+
637
+ # 添加紧急程度
638
+ urgency_display = "一般"
639
+ if urgency == "urgent":
640
+ urgency_display = "紧急"
641
+ elif urgency == "important":
642
+ urgency_display = "重要"
643
+
644
+ # 创建单行输出
645
+ output_for_df = []
646
+ output_for_df.append([1, combined_content, urgency_display])
647
+ else:
648
+ # 不是待办事项
649
+ output_for_df = []
650
+ output_for_df.append([1, "此消息不包含待办事项", "-"])
651
+
652
+ # 如果是旧格式(列表格式)
653
+ elif isinstance(parsed_todos_data, list):
654
+ output_for_df = []
655
+
656
+ # 检查列表是否为空
657
+ if not parsed_todos_data:
658
+ logger.warning("LLM 返回了空列表,无法生成 ToDo 项目")
659
+ return [[1, "未能生成待办事项", "-"]]
660
+
661
+ for i, item in enumerate(parsed_todos_data):
662
+ if isinstance(item, dict):
663
+ todo_content = item.get('todo_content', item.get('content', 'N/A'))
664
+ status = item.get('status', '未完成')
665
+ urgency = item.get('urgency', 'normal')
666
+
667
+ # 合并所有信息到一行
668
+ combined_content = todo_content
669
+
670
+ # 添加截止时间
671
+ if 'end_time' in item and item['end_time']:
672
+ try:
673
+ if isinstance(item['end_time'], str):
674
+ date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
675
+ combined_content += f" (截止时间: {date_part}"
676
+ else:
677
+ combined_content += f" (截止时间: {str(item['end_time'])}"
678
+ except Exception as e:
679
+ logger.warning(f"处理end_time时出错: {e}")
680
+ combined_content += " ("
681
+ else:
682
+ combined_content += " ("
683
+
684
+ # 添加地点
685
+ if 'location' in item and item['location']:
686
+ combined_content += f", 地点: {item['location']})"
687
+ else:
688
+ combined_content += ")"
689
+
690
+ # 设置重要等级
691
+ importance = "一般"
692
+ if urgency == "urgent":
693
+ importance = "紧急"
694
+ elif urgency == "important":
695
+ importance = "重要"
696
+
697
+ output_for_df.append([i + 1, combined_content, importance])
698
+ else:
699
+ # 如果不是字典,转换为字符串并添加到列表
700
+ try:
701
+ item_str = str(item) if item is not None else "未知项目"
702
+ output_for_df.append([i + 1, item_str, "一般"])
703
+ except Exception as e:
704
+ logger.warning(f"处理非字典项目时出错: {e}")
705
+ output_for_df.append([i + 1, "处理错误的项目", "一般"])
706
+
707
+ if not output_for_df:
708
+ logger.info("LLM 解析结果为空或无法转换为DataFrame格式。")
709
+ return [["info", "未发现待办事项", "-"]]
710
+
711
+ return output_for_df
712
+
713
+ except Exception as e:
714
+ logger.exception(f"调用 LLM 或解析时发生错误 (generate_todolist_from_text)")
715
+ return [["error", f"LLM call/parse error: {str(e)}", "-"]]
716
+
717
+ #gradio
718
+ def process(audio, image, request: gr.Request):
719
+ """处理语音和图片的示例函数"""
720
+ # 获取并记录客户端IP
721
+ client_ip = get_client_ip(request, True)
722
+ print(f"Processing audio/image request from IP: {client_ip}")
723
+
724
+ if audio is not None:
725
+ sample_rate, audio_data = audio
726
+ audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
727
+ else:
728
+ audio_info = "未收到音频"
729
+
730
+ if image is not None:
731
+ image_info = f"图片尺寸: {image.shape}"
732
+ else:
733
+ image_info = "未收到图片"
734
+
735
+ return audio_info, image_info
736
+
737
+ def respond(
738
+ message,
739
+ history: list[tuple[str, str]],
740
+ system_message,
741
+ max_tokens,
742
+ temperature,
743
+ top_p,
744
+ audio, # 多模态输入:音频
745
+ image # 多模态输入:图片
746
+ ):
747
+ # ... (聊天回复逻辑基本保持不变, 但确保 client 使用的是配置好的 HF client)
748
+ # 1. 多模态处理接口 (其他人负责)
749
+ # processed_text_from_multimodal = multimodal_placeholder_function(audio, image)
750
+ # 多模态处理:调用讯飞API进行语音和图像识别
751
+ multimodal_content = ""
752
+
753
+ # 多模态处理配置已移至具体处理部分
754
+
755
+ if audio is not None:
756
+ try:
757
+ audio_sample_rate, audio_data = audio
758
+ multimodal_content += f"\n[音频信息: 采样率 {audio_sample_rate}Hz, 时长 {len(audio_data)/audio_sample_rate:.2f}秒]"
759
+
760
+ # 调用Azure Speech语音识别
761
+ azure_speech_config = get_hf_azure_speech_config()
762
+ azure_speech_key = azure_speech_config.get('key')
763
+ azure_speech_region = azure_speech_config.get('region')
764
+
765
+ if azure_speech_key and azure_speech_region:
766
+ import tempfile
767
+ import soundfile as sf
768
+ import os
769
+
770
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
771
+ sf.write(temp_audio.name, audio_data, audio_sample_rate)
772
+ temp_audio_path = temp_audio.name
773
+
774
+ audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, temp_audio_path)
775
+ if audio_text:
776
+ multimodal_content += f"\n[语音识别结果: {audio_text}]"
777
+ else:
778
+ multimodal_content += "\n[语音识别失败]"
779
+
780
+ os.unlink(temp_audio_path)
781
+ else:
782
+ multimodal_content += "\n[Azure Speech API配置不完整,无法进行语音识别]"
783
+
784
+ except Exception as e:
785
+ multimodal_content += f"\n[音频处理错误: {str(e)}]"
786
+
787
+ if image is not None:
788
+ try:
789
+ multimodal_content += f"\n[图片信息: 尺寸 {image.shape}]"
790
+
791
+ # 调用讯飞图像识别
792
+ if xunfei_appid and xunfei_apikey and xunfei_apisecret:
793
+ import tempfile
794
+ from PIL import Image
795
+ import os
796
+
797
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
798
+ if len(image.shape) == 3: # RGB图像
799
+ pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
800
+ else: # 灰度图像
801
+ pil_image = Image.fromarray(image.astype('uint8'), 'L')
802
+
803
+ pil_image.save(temp_image.name, 'JPEG')
804
+ temp_image_path = temp_image.name
805
+
806
+ image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=temp_image_path)
807
+ if image_text:
808
+ multimodal_content += f"\n[图像识别结果: {image_text}]"
809
+ else:
810
+ multimodal_content += "\n[图像识别失败]"
811
+
812
+ os.unlink(temp_image_path)
813
+ else:
814
+ multimodal_content += "\n[讯飞API配置不完整,无法进行图像识别]"
815
+
816
+ except Exception as e:
817
+ multimodal_content += f"\n[图像处理错误: {str(e)}]"
818
+
819
+ # 将多模态内容(或其处理结果)与用户文本消息结合
820
+ # combined_message = message
821
+ # if multimodal_content: # 如果有多模态内容,则附加
822
+ # combined_message += "\n" + multimodal_content
823
+ # 为了聊天模型的连贯性,聊天部分可能只使用原始 message
824
+ # 而 ToDoList 生成则使用 combined_message
825
+
826
+ # 聊天回复生成
827
+ chat_messages = [{"role": "system", "content": system_message}]
828
+ for val in history:
829
+ if val[0]:
830
+ chat_messages.append({"role": "user", "content": val[0]})
831
+ if val[1]:
832
+ chat_messages.append({"role": "assistant", "content": val[1]})
833
+ chat_messages.append({"role": "user", "content": message}) # 聊天机器人使用原始消息
834
+
835
+ chat_response_stream = ""
836
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
837
+ logger.error("Filter API 配置不完整,无法调用 LLM333。")
838
+ yield "Filter API 配置不完整,无法提供聊天回复。", []
839
+ return
840
+
841
+ headers = {
842
+ "Authorization": f"Bearer {Filter_API_KEY}",
843
+ "Accept": "application/json"
844
+ }
845
+ payload = {
846
+ "model": Filter_MODEL_NAME,
847
+ "messages": chat_messages,
848
+ "temperature": temperature,
849
+ "top_p": top_p,
850
+ "max_tokens": max_tokens,
851
+ "stream": True # 聊天通常需要流式输出
852
+ }
853
+ api_url = f"{Filter_API_BASE_URL}/chat/completions"
854
+
855
+ try:
856
+ response = requests.post(api_url, headers=headers, json=payload, stream=True)
857
+ response.raise_for_status() # 检查 HTTP 错误
858
+
859
+ for chunk in response.iter_content(chunk_size=None):
860
+ if chunk:
861
+ try:
862
+ # NVIDIA API 的流式输出是 SSE 格式,需要解析
863
+ # 每一行以 'data: ' 开头,后面是 JSON
864
+ for line in chunk.decode('utf-8').splitlines():
865
+ if line.startswith('data: '):
866
+ json_data = line[len('data: '):]
867
+ if json_data.strip() == '[DONE]':
868
+ break
869
+ data = json.loads(json_data)
870
+ # 检查 choices 列表是否存在且不为空
871
+ if 'choices' in data and len(data['choices']) > 0:
872
+ token = data['choices'][0]['delta'].get('content', '')
873
+ if token:
874
+ chat_response_stream += token
875
+ yield chat_response_stream, []
876
+ except json.JSONDecodeError:
877
+ logger.warning(f"无法解析流式响应块: {chunk.decode('utf-8')}")
878
+ except Exception as e:
879
+ logger.error(f"处理流式响应时发生错误: {e}")
880
+ yield chat_response_stream + f"\n\n错误: {e}", []
881
+
882
+ except requests.exceptions.RequestException as e:
883
+ logger.error(f"调用 NVIDIA API 失败: {e}")
884
+ yield f"调用 NVIDIA API 失败: {e}", []
885
+
886
+ # 全局变量存储所有待办事项
887
+ all_todos_global = []
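+ # 注意:此为模块级全局变量,会在所有用户会话之间共享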
888
+
889
+ # 创建自定义的聊天界面
890
+ with gr.Blocks() as app:
891
+ gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")
892
+
893
+ with gr.Row():
894
+ with gr.Column(scale=2):
895
+ gr.Markdown("## Chat Interface")
896
+ chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages") # 推荐使用 type="messages"
897
+ msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")
898
+
899
+ with gr.Row():
900
+ audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
901
+ image_input = gr.Image(label="上传图片", type="numpy")
902
+
903
+ with gr.Accordion("高级设置", open=False):
904
+ system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
905
+ max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)") # 增加聊天模型参数范围
906
+ temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
907
+ top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")
908
+
909
+ with gr.Row():
910
+ submit_btn = gr.Button("发送", variant="primary")
911
+ clear_btn = gr.Button("清除聊天和ToDo")
912
+
913
+ with gr.Column(scale=1):
914
+ gr.Markdown("## Generated ToDo List")
915
+ todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
916
+ datatype=["number", "str", "str"],
917
+ row_count=(0, "dynamic"),
918
+ col_count=(3, "fixed"),
919
+ label="待办事项列表")
920
+
921
+ def user(user_message, chat_history):
922
+ # 将用户消息添加到聊天记录 (Gradio type="messages" 格式)
923
+ if not chat_history: chat_history = []
924
+ chat_history.append({"role": "user", "content": user_message})
925
+ return "", chat_history
926
+
927
+ def bot_interaction(chat_history, system_message, max_tokens, temperature, top_p, audio, image):
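+ # 说明:该函数在本文件中未绑定到任何 Gradio 事件,仅作为参考实现保留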
928
+ user_message_for_chat = ""
929
+ if chat_history and chat_history[-1]["role"] == "user":
930
+ user_message_for_chat = chat_history[-1]["content"]
931
+
932
+ # 准备用于 ToDoList 生成的输入文本 (多模态部分由其他人负责)
933
+ text_for_todolist = user_message_for_chat
934
+ # 可以在这里添加从 audio/image 提取文本的逻辑,并附加到 text_for_todolist
935
+ # multimodal_text = process_multimodal_inputs(audio, image) # 假设的函数
936
+ # if multimodal_text:
937
+ # text_for_todolist += "\n" + multimodal_text
938
+
939
+ # 1. 生成聊天回复 (流式)
940
+ # 转换 chat_history 从 [{'role':'user', 'content':'...'}, ...] 到 [('user_msg', 'bot_msg'), ...]
941
+ # respond 函数期望的是 history: list[tuple[str, str]]
942
+ # 但 Gradio type="messages" 的 chatbot.value 是 [{'role': ..., 'content': ...}, ...]
943
+ # 需要转换
944
+ formatted_history_for_respond = []
945
+ temp_user_msg = None
946
+ for item in chat_history[:-1]: #排除最后一条用户消息,因为它会作为当前message传入respond
947
+ if item["role"] == "user":
948
+ temp_user_msg = item["content"]
949
+ elif item["role"] == "assistant" and temp_user_msg is not None:
950
+ formatted_history_for_respond.append((temp_user_msg, item["content"]))
951
+ temp_user_msg = None
952
+ elif item["role"] == "assistant" and temp_user_msg is None: # Bot 先说话的情况
953
+ formatted_history_for_respond.append(("", item["content"]))
954
+
955
+ chat_stream_generator = respond(
956
+ user_message_for_chat,
957
+ formatted_history_for_respond, # 传递转换后的历史
958
+ system_message,
959
+ max_tokens,
960
+ temperature,
961
+ top_p,
962
+ audio,
963
+ image
964
+ )
965
+
966
+ full_chat_response = ""
967
+ current_todos = []
968
+
969
+ for chat_response_part, _ in chat_stream_generator:
970
+ full_chat_response = chat_response_part
971
+ # 更新 chat_history (Gradio type="messages" 格式)
972
+ if chat_history and chat_history[-1]["role"] == "user":
973
+ # 如果最后一条是用户消息,添加机器人回复
974
+ # 但由于是流式,我们可能需要先添加一个空的 assistant 消息,然后更新它
975
+ # 或者,等待流结束后一次性添加
976
+ # 为了简化,我们先假设 respond 返回的是完整回复,或者在循环外更新
977
+ pass # 流式更新 chatbot 在 submit_btn.click 中处理
978
+ yield chat_history + [[None, full_chat_response]], current_todos # 临时做法,需要适配Gradio的流式更新
979
+
980
+ # 流式结束后,更新 chat_history 中的最后一条 assistant 消息
981
+ if chat_history and full_chat_response:
982
+ # 查找最后一条用户消息,在其后添加或更新机器人回复
983
+ # 这种方式对于 type="messages" 更友好
984
+ # 实际上,Gradio 的 chatbot 更新应该在 .then() 中处理,这里先模拟
985
+ # chat_history.append({"role": "assistant", "content": full_chat_response})
986
+ # 这个 yield 应该在 submit_btn.click 的 .then() 中处理 chatbot 的更新
987
+ # 这里我们先专注于 ToDo 生成
988
+ pass # chatbot 更新由 Gradio 机制处理
989
+
990
+ # 2. 聊天回复完成后,生成/更新 ToDoList
991
+ if text_for_todolist:
992
+ # 使用一个唯一的 ID,例如基于时间戳或随机数,如果需要区分不同输入的 ToDo
993
+ message_id_for_todo = f"hf_app_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
994
+ new_todo_items = generate_todolist_from_text(text_for_todolist, message_id_for_todo)
995
+ current_todos = new_todo_items
996
+
997
+ # bot_interaction 应该返回 chatbot 的最终状态和 todolist_df 的数据
998
+ # chatbot 的最终状态是 chat_history + assistant 的回复
999
+ final_chat_history = list(chat_history) # 复制
1000
+ if full_chat_response:
1001
+ final_chat_history.append({"role": "assistant", "content": full_chat_response})
1002
+
1003
+ yield final_chat_history, current_todos
1004
+
1005
+ # 连接事件 (适配 type="messages")
1006
+ # Gradio 的流式更新通常是:
1007
+ # 1. user 函数准备输入,返回 (空输入框, 更新后的聊天记录)
1008
+ # 2. bot_interaction 函数是一个生成器,yield (部分聊天记录, 部分ToDo)
1009
+ # msg.submit 和 submit_btn.click 的 outputs 需要对应 bot_interaction 的 yield
1010
+
1011
+ # 简化版,非流式更新 chatbot,流式更新由 respond 内部的 yield 控制
1012
+ # 但 respond 的 yield 格式 (str, list) 与 bot_interaction (list, list) 不同
1013
+ # 需要调整 respond 的 yield 或 bot_interaction 的处理
1014
+
1015
+ # 调整后的事件处理,以更好地支持流式聊天和ToDo更新
1016
+ def process_filtered_result_for_todo(filtered_result, content, source_type):
1017
+ """处理过滤结果并生成todolist的辅助函数"""
1018
+ todos = []
1019
+
1020
+ if isinstance(filtered_result, dict) and "error" in filtered_result:
1021
+ logger.error(f"{source_type} Filter 模块处理失败: {filtered_result['error']}")
1022
+ todos = [["Error", f"{source_type}: {filtered_result['error']}", "Filter Failed"]]
1023
+ elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
1024
+ logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
1025
+ todos = [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
1026
+ elif isinstance(filtered_result, list):
1027
+ # 处理列表类型的过滤结果
1028
+ category = None
1029
+ if filtered_result:
1030
+ for item in filtered_result:
1031
+ if isinstance(item, dict) and "分类" in item:
1032
+ category = item["分类"]
1033
+ break
1034
+
1035
+ if category == "其他":
1036
+ logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
1037
+ todos = [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
1038
+ else:
1039
+ logger.info(f"{source_type}消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。")
1040
+ if content:
1041
+ msg_id_todo = f"hf_app_todo_{source_type}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
1042
+ todos = generate_todolist_from_text(content, msg_id_todo)
1043
+ # 为每个todo添加来源标识
1044
+ for todo in todos:
1045
+ if len(todo) > 1:
1046
+ todo[1] = f"[{source_type}] {todo[1]}"
1047
+ else:
1048
+ # 如果是字典但不是"其他"分类
1049
+ logger.info(f"{source_type}消息被 Filter 模块归类为 '{filtered_result.get('分类') if isinstance(filtered_result, dict) else '未知'}',继续生成 ToDo List。")
1050
+ if content:
1051
+ msg_id_todo = f"hf_app_todo_{source_type}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
1052
+ todos = generate_todolist_from_text(content, msg_id_todo)
1053
+ # 为每个todo添加来源标识
1054
+ for todo in todos:
1055
+ if len(todo) > 1:
1056
+ todo[1] = f"[{source_type}] {todo[1]}"
1057
+
1058
+ return todos
1059
+
1060
+ def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f, request: gr.Request):
1061
+ global all_todos_global
1062
+
1063
+ # 获取并记录客户端IP
1064
+ client_ip = get_client_ip(request, True)
1065
+ print(f"Processing request from IP: {client_ip}")
1066
+
1067
+ # 首先处理多模态输入,获取多模态内容
1068
+ multimodal_text_content = ""
1069
+ # 添加调试日志
1070
+ logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")
1071
+
1072
+ # 获取Azure Speech配置
1073
+ azure_speech_config = get_hf_azure_speech_config()
1074
+ azure_speech_key = azure_speech_config.get('key')
1075
+ azure_speech_region = azure_speech_config.get('region')
1076
+
1077
+ # 添加调试日志
1078
+ logger.info(f"Azure Speech配置状态 - key: {bool(azure_speech_key)}, region: {bool(azure_speech_region)}")
1079
+
1080
+ # 处理音频输入(使用Azure Speech服务)
1081
+ if audio_f is not None and azure_speech_key and azure_speech_region:
1082
+ logger.info("开始处理音频输入...")
1083
+ try:
1084
+ audio_sample_rate, audio_data = audio_f
1085
+ logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")
1086
+
1087
+ # 保存音频为.wav文件
1088
+ audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
1089
+ save_audio(audio_f, audio_filename)
1090
+ logger.info(f"音频已保存: {audio_filename}")
1091
+
1092
+ # 调用Azure Speech服务处理音频
1093
+ audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename)
1094
+ logger.info(f"音频识别结果: {audio_text}")
1095
+ if audio_text:
1096
+ multimodal_text_content += f"音频内容: {audio_text}"
1097
+ logger.info("音频处理完成")
1098
+ else:
1099
+ logger.warning("音频处理失败")
1100
+ except Exception as e:
1101
+ logger.error(f"音频处理错误: {str(e)}")
1102
+ elif audio_f is not None:
1103
+ logger.warning("音频文件存在但Azure Speech配置不完整,跳过音频处理")
1104
+
1105
+ # 处理图像输入(使用Azure Computer Vision服务)
1106
+ if image_f is not None:
1107
+ logger.info("开始处理图像输入...")
1108
+ try:
1109
+ logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")
1110
+
1111
+ # 保存图片为.jpg文件
1112
+ image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
1113
+ save_image(image_f, image_filename)
1114
+ logger.info(f"图像已保存: {image_filename}")
1115
+
1116
+ # 调用tools.py中的image_to_str方法处理图片
1117
+ image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
1118
+ logger.info(f"图像识别结果: {image_text}")
1119
+ if image_text:
1120
+ if multimodal_text_content: # 如果已有音频内容,添加分隔符
1121
+ multimodal_text_content += "\n"
1122
+ multimodal_text_content += f"图像内容: {image_text}"
1123
+ logger.info("图像处理完成")
1124
+ else:
1125
+ logger.warning("图像处理失败")
1126
+ except Exception as e:
1127
+ logger.error(f"图像处理错误: {str(e)}")
1128
+ elif image_f is not None:
1129
+ logger.warning("图像文件存在但处理失败,跳过图像处理")
1130
+
1131
+ # 确定最终的用户输入内容:如果用户没有输入文本,使用多模态识别的内容
1132
+ final_user_content = user_msg_content.strip() if user_msg_content else ""
1133
+ if not final_user_content and multimodal_text_content:
1134
+ final_user_content = multimodal_text_content
1135
+ logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
1136
+ elif final_user_content and multimodal_text_content:
1137
+ # 用户有文本输入,多模态内容作为补充
1138
+ final_user_content = f"{final_user_content}\n{multimodal_text_content}"
1139
+ logger.info(f"用户有文本输入,多模态内容作为补充")
1140
+
1141
+ # 如果最终还是没有任何内容,提供默认提示
1142
+ if not final_user_content:
1143
+ final_user_content = "[无输入内容]"
1144
+ logger.warning("用户没有提供任何输入内容(文本、音频或图像)")
1145
+
1146
+ logger.info(f"最终用户输入内容: {final_user_content}")
1147
+
1148
+ # 1. 更新聊天记录 (用户部分) - 使用最终确定的用户内容
1149
+ if not ch_history: ch_history = []
1150
+ ch_history.append({"role": "user", "content": final_user_content})
1151
+ yield ch_history, [] # 更新聊天,ToDo 列表暂时不变
1152
+
1153
+ # 2. 流式生成机器人回复并更新聊天记录
1154
+ # 转换 chat_history 为 respond 函数期望的格式
1155
+ formatted_hist_for_respond = []
1156
+ temp_user_msg_for_hist = None
1157
+ # 使用 ch_history[:-1] 因为当前用户消息已在 ch_history 中
1158
+ for item_hist in ch_history[:-1]:
1159
+ if item_hist["role"] == "user":
1160
+ temp_user_msg_for_hist = item_hist["content"]
1161
+ elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
1162
+ formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
1163
+ temp_user_msg_for_hist = None
1164
+ elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
1165
+ formatted_hist_for_respond.append(("", item_hist["content"]))
1166
+
1167
+ # 准备一个 assistant 消息的槽位
1168
+ ch_history.append({"role": "assistant", "content": ""})
1169
+
1170
+ full_bot_response = ""
1171
+ # 使用最终确定的用户内容进行对话
1172
+ for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
1173
+ full_bot_response = bot_response_token
1174
+ ch_history[-1]["content"] = full_bot_response # 更新最后一条 assistant 消息
1175
+ yield ch_history, [] # 流式更新聊天,ToDo 列表不变
1176
+
1177
+ # 3. 生成 ToDoList - 分别处理音频、图片和文字输入
1178
+ new_todos_list = []
1179
+
1180
+ # 分别处理文字输入
1181
+ if user_msg_content and user_msg_content.strip():
1182
+ logger.info(f"处理文字输入生成ToDo: {user_msg_content.strip()}")
1183
+ text_filtered_result = filter_message_with_llm(user_msg_content.strip())
1184
+ text_todos = process_filtered_result_for_todo(text_filtered_result, user_msg_content.strip(), "文字")
1185
+ new_todos_list.extend(text_todos)
1186
+
1187
+ # 分别处理音频输入
1188
+ if audio_f is not None and azure_speech_key and azure_speech_region:
1189
+ try:
1190
+ audio_sample_rate, audio_data = audio_f
1191
+ audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
1192
+ save_audio(audio_f, audio_filename)
1193
+ audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename)
1194
+ if audio_text:
1195
+ logger.info(f"处理音频输入生成ToDo: {audio_text}")
1196
+ audio_filtered_result = filter_message_with_llm(audio_text)
1197
+ audio_todos = process_filtered_result_for_todo(audio_filtered_result, audio_text, "音频")
1198
+ new_todos_list.extend(audio_todos)
1199
+ except Exception as e:
1200
+ logger.error(f"音频处理错误: {str(e)}")
1201
+
1202
+ # 分别处理图片输入
1203
+ if image_f is not None:
1204
+ try:
1205
+ image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
1206
+ save_image(image_f, image_filename)
1207
+ image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
1208
+ if image_text:
1209
+ logger.info(f"处理图片输入生成ToDo: {image_text}")
1210
+ image_filtered_result = filter_message_with_llm(image_text)
1211
+ image_todos = process_filtered_result_for_todo(image_filtered_result, image_text, "图片")
1212
+ new_todos_list.extend(image_todos)
1213
+ except Exception as e:
1214
+ logger.error(f"图片处理错误: {str(e)}")
1215
+
1216
+ # 如果没有任何有效输入,使用原有逻辑
1217
+ if not new_todos_list and final_user_content:
1218
+ logger.info(f"使用整合内容生成ToDo: {final_user_content}")
1219
+ filtered_result = filter_message_with_llm(final_user_content)
1220
+
1221
+ if isinstance(filtered_result, dict) and "error" in filtered_result:
1222
+ logger.error(f"Filter 模块处理失败: {filtered_result['error']}")
1223
+ # 可以选择在这里显示错误信息给用户
1224
+ new_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
1225
+ elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
1226
+ logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
1227
+ new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
1228
+ elif isinstance(filtered_result, list):
1229
+ # 如果返回的是列表,尝试从列表中获取分类信息
1230
+ category = None
1231
+
1232
+ # 检查列表是否为空
1233
+ if not filtered_result:
1234
+ logger.warning("Filter 模块返回了空列表,将继续生成 ToDo List。")
1235
+ if final_user_content:
1236
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
1237
+ new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
1238
+ # 将新的待办事项添加到全局列表中
1239
+ if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])):
1240
+ # 重新分配ID以确保连续性
1241
+ for i, todo in enumerate(new_todos_list):
1242
+ todo[0] = len(all_todos_global) + i + 1
1243
+ all_todos_global.extend(new_todos_list)
1244
+ yield ch_history, all_todos_global
1245
+ return
1246
+
1247
+ # 确保列表中至少有一个元素且是字典类型
1248
+ valid_item = None
1249
+ for item in filtered_result:
1250
+ if isinstance(item, dict):
1251
+ valid_item = item
1252
+ if "分类" in item:
1253
+ category = item["分类"]
1254
+ break
1255
+
1256
+ # 如果没有找到有效的字典元素,记录警告并继续生成ToDo
1257
+ if valid_item is None:
1258
+ logger.warning(f"Filter 模块返回的列表中没有有效的字典元素: {filtered_result}")
1259
+ if final_user_content:
1260
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
1261
+ new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
1262
+ # 将新的待办事项添加到全局列表中
1263
+ if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])):
1264
+ # 重新分配ID以确保连续性
1265
+ for i, todo in enumerate(new_todos_list):
1266
+ todo[0] = len(all_todos_global) + i + 1
1267
+ all_todos_global.extend(new_todos_list)
1268
+ yield ch_history, all_todos_global
1269
+ return
1270
+
1271
+ if category == "其他":
1272
+ logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
1273
+ new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
1274
+ else:
1275
+ logger.info(f"消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。")
1276
+ # 如果 Filter 结果不是"其他",则继续生成 ToDoList
1277
+ if final_user_content:
1278
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
1279
+ new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
1280
+ else:
1281
+ # 如果是字典但不是"其他"分类
1282
+ logger.info(f"消息被 Filter 模块归类为 '{filtered_result.get('分类')}',继续生成 ToDo List。")
1283
+ # 如果 Filter 结果不是"其他",则继续生成 ToDoList
1284
+ if final_user_content:
1285
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
1286
+ new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
1287
+
1288
+ # 将新的待办事项添加到全局列表中(排除信息性消息)
1289
+ if new_todos_list and not (len(new_todos_list) == 1 and ("Info" in str(new_todos_list[0]) or "Error" in str(new_todos_list[0]))):
1290
+ # 重新分配ID以确保连续性
1291
+ for i, todo in enumerate(new_todos_list):
1292
+ todo[0] = len(all_todos_global) + i + 1
1293
+ all_todos_global.extend(new_todos_list)
1294
+
1295
+ yield ch_history, all_todos_global # 最终更新聊天和完整的ToDo列表
1296
+
1297
+ submit_btn.click(
1298
+ handle_submit,
1299
+ [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
1300
+ [chatbot, todolist_df]
1301
+ )
1302
+ msg.submit(
1303
+ handle_submit,
1304
+ [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
1305
+ [chatbot, todolist_df]
1306
+ )
1307
+
1308
+ def clear_all():
1309
+ global all_todos_global
1310
+ all_todos_global = [] # 清除全局待办事项列表
1311
+ return None, None, "" # 清除 chatbot, todolist_df, 和 msg 输入框
1312
+ clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)
1313
+
1314
+ # 旧的 Audio/Image Processing Tab (保持不变或按需修改)
1315
+ with gr.Tab("Audio/Image Processing (Original)"):
1316
+ gr.Markdown("## 处理音频和图片")
1317
+ audio_processor = gr.Audio(label="上传音频", type="numpy")
1318
+ image_processor = gr.Image(label="上传图片", type="numpy")
1319
+ process_btn = gr.Button("处理", variant="primary")
1320
+ audio_output = gr.Textbox(label="音频信息")
1321
+ image_output = gr.Textbox(label="图片信息")
1322
+
1323
+ process_btn.click(
1324
+ process,
1325
+ inputs=[audio_processor, image_processor],
1326
+ outputs=[audio_output, image_output]
1327
+ )
1328
+
1329
+ if __name__ == "__main__":
1330
+ app.launch(debug=True)
app_pro.py ADDED
@@ -0,0 +1,840 @@
1
+ import gradio as gr
2
+ import json
3
+ from pathlib import Path
4
+ import yaml
5
+ import re
6
+ import logging
7
+ import io
8
+ import sys
9
+ import re
10
+ from datetime import datetime, timezone, timedelta
11
+ import requests
12
+ from tools import * #gege的多模态
13
+
14
+
15
+ CONFIG = None
16
+ HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"
17
+
18
+ def load_hf_config():
19
+ """加载YAML配置文件"""
20
+ global CONFIG
21
+ if CONFIG is None:
22
+ try:
23
+ with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
24
+ CONFIG = yaml.safe_load(f)
25
+ print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
26
+ except FileNotFoundError:
27
+ print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
28
+ CONFIG = {}
29
+ except Exception as e:
30
+ print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
31
+ CONFIG = {}
32
+ return CONFIG
33
+
34
+ def get_hf_openai_config():
35
+ """获取OpenAI API配置"""
36
+ config = load_hf_config()
37
+ return config.get('openai', {})
38
+
39
+ def get_hf_openai_filter_config():
40
+ """获取Filter API配置"""
41
+ config = load_hf_config()
42
+ return config.get('openai_filter', {})
43
+
44
+ def get_hf_xunfei_config():
45
+ """获取讯飞API配置"""
46
+ config = load_hf_config()
47
+ return config.get('xunfei', {})
48
+
49
+ def get_hf_paths_config():
50
+ """获取文件路径配置"""
51
+ config = load_hf_config()
52
+ base = Path(__file__).resolve().parent
53
+ paths_cfg = config.get('paths', {})
54
+ return {
55
+ 'base_dir': base,
56
+ 'prompt_template': base / paths_cfg.get('prompt_template', 'prompt_template.txt'),
57
+ 'true_positive_examples': base / paths_cfg.get('true_positive_examples', 'TruePositive_few_shot.txt'),
58
+ 'false_positive_examples': base / paths_cfg.get('false_positive_examples', 'FalsePositive_few_shot.txt'),
59
+ }
60
+
61
+ llm_config = get_hf_openai_config()
62
+ NVIDIA_API_BASE_URL = llm_config.get('base_url')
63
+ NVIDIA_API_KEY = llm_config.get('api_key')
64
+ NVIDIA_MODEL_NAME = llm_config.get('model')
65
+
66
+ filter_config = get_hf_openai_filter_config()
67
+ Filter_API_BASE_URL = filter_config.get('base_url_filter')
68
+ Filter_API_KEY = filter_config.get('api_key_filter')
69
+ Filter_MODEL_NAME = filter_config.get('model_filter')
70
+
71
+ if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
72
+ print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
73
+ NVIDIA_API_BASE_URL = ""
74
+ NVIDIA_API_KEY = ""
75
+ NVIDIA_MODEL_NAME = ""
76
+
77
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
78
+ print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
79
+ Filter_API_BASE_URL = ""
80
+ Filter_API_KEY = ""
81
+ Filter_MODEL_NAME = ""
82
+
83
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
84
+ logger = logging.getLogger(__name__)
85
+
86
+ def load_single_few_shot_file_hf(file_path: Path) -> str:
87
+ """加载单个few-shot示例文件并转义大括号"""
88
+ try:
89
+ with open(file_path, 'r', encoding='utf-8') as f:
90
+ content = f.read()
91
+ escaped_content = content.replace('{', '{{').replace('}', '}}')
92
+ return escaped_content
93
+ except FileNotFoundError:
94
+ return ""
95
+ except Exception:
96
+ return ""
97
+
98
+ PROMPT_TEMPLATE_CONTENT = ""
99
+ TRUE_POSITIVE_EXAMPLES_CONTENT = ""
100
+ FALSE_POSITIVE_EXAMPLES_CONTENT = ""
101
+
102
+ def load_prompt_data_hf():
103
+ """加载提示词模板和示例数据"""
104
+ global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
105
+ paths = get_hf_paths_config()
106
+ try:
107
+ with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
108
+ PROMPT_TEMPLATE_CONTENT = f.read()
109
+ except FileNotFoundError:
110
+ PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."
111
+
112
+ TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
113
+ FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])
114
+
115
+ load_prompt_data_hf()
116
+
117
+ def _process_parsed_json(parsed_data):
118
+ """处理解析后的JSON数据,确保格式正确"""
119
+ try:
120
+ if isinstance(parsed_data, list):
121
+ if not parsed_data:
122
+ return [{}]
123
+
124
+ processed_list = []
125
+ for item in parsed_data:
126
+ if isinstance(item, dict):
127
+ processed_list.append(item)
128
+ else:
129
+ try:
130
+ processed_list.append({"content": str(item)})
131
+ except:
132
+ processed_list.append({"content": "无法转换的项目"})
133
+
134
+ if not processed_list:
135
+ return [{}]
136
+
137
+ return processed_list
138
+
139
+ elif isinstance(parsed_data, dict):
140
+ return parsed_data
141
+
142
+ else:
143
+ return {"content": str(parsed_data)}
144
+
145
+ except Exception as e:
146
+ return {"error": f"Error processing parsed JSON: {e}"}
147
+
148
+ def json_parser(text: str) -> dict:
149
+ """从文本中解析JSON数据,支持多种格式"""
150
+ try:
151
+ try:
152
+ parsed_data = json.loads(text)
153
+ return _process_parsed_json(parsed_data)
154
+ except json.JSONDecodeError:
155
+ pass
156
+
157
+ match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
158
+ if match:
159
+ json_str = match.group(1).strip()
160
+ json_str = re.sub(r',\s*]', ']', json_str)
161
+ json_str = re.sub(r',\s*}', '}', json_str)
162
+ try:
163
+ parsed_data = json.loads(json_str)
164
+ return _process_parsed_json(parsed_data)
165
+ except json.JSONDecodeError:
166
+ pass
167
+
168
+ array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
169
+ if array_match:
170
+ potential_json = array_match.group(0).strip()
171
+ try:
172
+ parsed_data = json.loads(potential_json)
173
+ return _process_parsed_json(parsed_data)
174
+ except json.JSONDecodeError:
175
+ pass
176
+
177
+ object_match = re.search(r'\{.*?\}', text, re.DOTALL)
178
+ if object_match:
179
+ potential_json = object_match.group(0).strip()
180
+ try:
181
+ parsed_data = json.loads(potential_json)
182
+ return _process_parsed_json(parsed_data)
183
+ except json.JSONDecodeError:
184
+ pass
185
+
186
+ return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
187
+
188
+ except Exception as e:
189
+ return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
190
+
191
+ def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
192
+ """使用LLM对消息进行分类过滤"""
193
+ mock_data = [(text_input, message_id)]
194
+
195
+ system_prompt = """
196
+ # 角色
197
+ 你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
198
+
199
+ # 任务
200
+ 对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、欠费缴纳、待付(还)款、会议邀约、其他】的可能性百分比。
201
+ 对于聊天、问候、回执、结果通知、上月账单等不需要收件人进一步处理的信息,直接归到其他类进行忽略
202
+
203
+ # 要求
204
+ 1. 以json格式输出
205
+ 2. content简洁提炼关键词,字符数<20以内
206
+ 3. 输入条数和输出条数完全一样
207
+
208
+ # 输出示例
209
+ ```
210
+ [
211
+ {"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
212
+ {"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
213
+ ]
214
+ ```
215
+ """
216
+
217
+ llm_messages = [
218
+ {"role": "system", "content": system_prompt},
219
+ {"role": "user", "content": str(mock_data)}
220
+ ]
221
+
222
+ try:
223
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
224
+ return [{"error": "Filter API configuration incomplete", "-": "-"}]
225
+
226
+ headers = {
227
+ "Authorization": f"Bearer {Filter_API_KEY}",
228
+ "Accept": "application/json"
229
+ }
230
+ payload = {
231
+ "model": Filter_MODEL_NAME,
232
+ "messages": llm_messages,
233
+ "temperature": 0.0,
234
+ "top_p": 0.95,
235
+ "max_tokens": 1024,
236
+ "stream": False
237
+ }
238
+
239
+ api_url = f"{Filter_API_BASE_URL}/chat/completions"
240
+
241
+ try:
242
+ response = requests.post(api_url, headers=headers, json=payload)
243
+ response.raise_for_status()
244
+ raw_llm_response = response.json()["choices"][0]["message"]["content"]
245
+ except requests.exceptions.RequestException as e:
246
+ return [{"error": f"Filter API call failed: {e}", "-": "-"}]
247
+
248
+ raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
249
+ parsed_filter_data = json_parser(raw_llm_response)
250
+
251
+ if "error" in parsed_filter_data:
252
+ return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]
253
+
254
+ if isinstance(parsed_filter_data, list) and parsed_filter_data:
255
+ for item in parsed_filter_data:
256
+ if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
257
+ item["分类"] = "其他"
258
+
259
+ request_id_list = {message_id}
260
+ response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
261
+ diff = request_id_list - response_id_list
262
+
263
+ if diff:
264
+ for missed_id in diff:
265
+ parsed_filter_data.append({
266
+ "message_id": missed_id,
267
+ "content": text_input[:20],
268
+ "物流取件": 0,
269
+ "欠费缴纳": 0,
270
+ "待付(还)款": 0,
271
+ "会议邀约": 0,
272
+ "其他": 100,
273
+ "分类": "其他"
274
+ })
275
+
276
+ return parsed_filter_data
277
+ else:
278
+ return [{
279
+ "message_id": message_id,
280
+ "content": text_input[:20],
281
+ "物流取件": 0,
282
+ "欠费缴纳": 0,
283
+ "待付(还)款": 0,
284
+ "会议邀约": 0,
285
+ "其他": 100,
286
+ "分类": "其他",
287
+ "error": "Filter LLM returned empty or unexpected format"
288
+ }]
289
+
290
+ except Exception as e:
291
+ return [{
292
+ "message_id": message_id,
293
+ "content": text_input[:20],
294
+ "物流取件": 0,
295
+ "欠费缴纳": 0,
296
+ "待付(还)款": 0,
297
+ "会议邀约": 0,
298
+ "其他": 100,
299
+ "分类": "其他",
300
+ "error": f"Filter LLM call/parse error: {str(e)}"
301
+ }]
302
+
303
+ def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
304
+ """从文本生成待办事项列表"""
305
+ if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
306
+ return [["error", "Prompt template not loaded", "-"]]
307
+
308
+ current_time_iso = datetime.now(timezone.utc).isoformat()
309
+ content_escaped = text_input.replace('{', '{{').replace('}', '}}')
310
+
311
+ formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
312
+ true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
313
+ false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
314
+ current_time=current_time_iso,
315
+ message_id=message_id,
316
+ content_escaped=content_escaped
317
+ )
318
+
319
+ enhanced_prompt = formatted_prompt + """
320
+
321
+ # 重要提示
322
+ 请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
323
+ 你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
324
+ """
325
+
326
+ llm_messages = [
327
+ {"role": "user", "content": enhanced_prompt}
328
+ ]
329
+
330
+ try:
331
+ if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
332
+ todo_item = {
333
+ message_id: {
334
+ "is_todo": True,
335
+ "end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
336
+ "location": "线上:中国移动APP",
337
+ "todo_content": "缴纳话费",
338
+ "urgency": "important"
339
+ }
340
+ }
341
+
342
+ todo_content = "缴纳话费"
343
+ end_time = todo_item[message_id]["end_time"].split("T")[0]
344
+ location = todo_item[message_id]["location"]
345
+
346
+ combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
347
+
348
+ output_for_df = []
349
+ output_for_df.append([1, combined_content, "重要"])
350
+
351
+ return output_for_df
352
+
353
+ elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
354
+ meeting_time = None
355
+ meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
356
+ meeting_match = re.search(meeting_pattern, text_input)
357
+
358
+ if meeting_match:
359
+ meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
360
+ else:
361
+ meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
362
+
363
+ todo_item = {
364
+ message_id: {
365
+ "is_todo": True,
366
+ "end_time": meeting_time,
367
+ "location": "线上:会议软件",
368
+ "todo_content": "参加会议",
369
+ "urgency": "important"
370
+ }
371
+ }
372
+
373
+ todo_content = "参加会议"
374
+ end_time = todo_item[message_id]["end_time"].split("T")[0]
375
+ location = todo_item[message_id]["location"]
376
+
377
+ combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
378
+
379
+ output_for_df = []
380
+ output_for_df.append([1, combined_content, "重要"])
381
+
382
+ return output_for_df
383
+
384
+ elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
385
+ pickup_code = None
386
+ code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
387
+ code_match = re.search(code_pattern, text_input)
388
+
389
+ todo_content = "取快递"
390
+ if code_match:
391
+ pickup_code = code_match.group(1)
392
+ todo_content = f"取快递(取件码:{pickup_code})"
393
+
394
+ todo_item = {
395
+ message_id: {
396
+ "is_todo": True,
397
+ "end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
398
+ "location": "线下:快递柜",
399
+ "todo_content": todo_content,
400
+ "urgency": "important"
401
+ }
402
+ }
403
+
404
+ end_time = todo_item[message_id]["end_time"].split("T")[0]
405
+ location = todo_item[message_id]["location"]
406
+
407
+ combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
408
+
409
+ output_for_df = []
410
+ output_for_df.append([1, combined_content, "重要"])
411
+
412
+ return output_for_df
413
+
414
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
415
+ return [["error", "Filter API configuration incomplete", "-"]]
416
+
417
+ headers = {
418
+ "Authorization": f"Bearer {Filter_API_KEY}",
419
+ "Accept": "application/json"
420
+ }
421
+ payload = {
422
+ "model": Filter_MODEL_NAME,
423
+ "messages": llm_messages,
424
+ "temperature": 0.2,
425
+ "top_p": 0.95,
426
+ "max_tokens": 1024,
427
+ "stream": False
428
+ }
429
+
430
+ api_url = f"{Filter_API_BASE_URL}/chat/completions"
431
+
432
+ try:
433
+ response = requests.post(api_url, headers=headers, json=payload)
434
+ response.raise_for_status()
435
+ raw_llm_response = response.json()['choices'][0]['message']['content']
436
+ except requests.exceptions.RequestException as e:
437
+ return [["error", f"Filter API call failed: {e}", "-"]]
438
+
439
+ parsed_todos_data = json_parser(raw_llm_response)
440
+
441
+ if "error" in parsed_todos_data:
442
+ return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]
443
+
444
+ output_for_df = []
445
+
446
+ if isinstance(parsed_todos_data, dict):
447
+ todo_info = None
448
+ for key, value in parsed_todos_data.items():
449
+ if key == message_id or key == str(message_id):
450
+ todo_info = value
451
+ break
452
+
453
+ if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
454
+ todo_content = todo_info.get("todo_content", "未指定待办内容")
455
+ end_time = todo_info.get("end_time")
456
+ location = todo_info.get("location")
457
+ urgency = todo_info.get("urgency", "unimportant")
458
+
459
+ combined_content = todo_content
460
+
461
+ if end_time and end_time != "null":
462
+ try:
463
+ date_part = end_time.split("T")[0] if "T" in end_time else end_time
464
+ combined_content += f" (截止时间: {date_part}"
465
+ except:
466
+ combined_content += f" (截止时间: {end_time}"
467
+ else:
468
+ combined_content += " ("
469
+
470
+ if location and location != "null":
471
+ combined_content += f", 地点: {location})"
472
+ else:
473
+ combined_content += ")"
474
+
475
+ urgency_display = "一般"
476
+ if urgency == "urgent":
477
+ urgency_display = "紧急"
478
+ elif urgency == "important":
479
+ urgency_display = "重要"
480
+
481
+ output_for_df = []
482
+ output_for_df.append([1, combined_content, urgency_display])
483
+ else:
484
+ output_for_df = []
485
+ output_for_df.append([1, "此消息不包含待办事项", "-"])
486
+
487
+ elif isinstance(parsed_todos_data, list):
488
+ output_for_df = []
489
+
490
+ if not parsed_todos_data:
491
+ return [[1, "未能生成待办事项", "-"]]
492
+
493
+ for i, item in enumerate(parsed_todos_data):
494
+ if isinstance(item, dict):
495
+ todo_content = item.get('todo_content', item.get('content', 'N/A'))
496
+ status = item.get('status', '未完成')
497
+ urgency = item.get('urgency', 'normal')
498
+
499
+ combined_content = todo_content
500
+
501
+ if 'end_time' in item and item['end_time']:
502
+ try:
503
+ if isinstance(item['end_time'], str):
504
+ date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
505
+ combined_content += f" (截止时间: {date_part}"
506
+ else:
507
+ combined_content += f" (截止时间: {str(item['end_time'])}"
508
+ except Exception:
509
+ combined_content += " ("
510
+ else:
511
+ combined_content += " ("
512
+
513
+ if 'location' in item and item['location']:
514
+ combined_content += f", 地点: {item['location']})"
515
+ else:
516
+ combined_content += ")"
517
+
518
+ importance = "一般"
519
+ if urgency == "urgent":
520
+ importance = "紧急"
521
+ elif urgency == "important":
522
+ importance = "重要"
523
+
524
+ output_for_df.append([i + 1, combined_content, importance])
525
+ else:
526
+ try:
527
+ item_str = str(item) if item is not None else "未知项目"
528
+ output_for_df.append([i + 1, item_str, "一般"])
529
+ except Exception:
530
+ output_for_df.append([i + 1, "处理错误的项目", "一般"])
531
+
532
+ if not output_for_df:
533
+ return [["info", "未发现待办事项", "-"]]
534
+
535
+ return output_for_df
536
+
537
+ except Exception as e:
538
+ return [["error", f"LLM call/parse error: {str(e)}", "-"]]
539
+ # 这里------多模态数据从这里调用
540
+ def process(audio, image):
541
+ """处理音频和图片输入,返回基本信息"""
542
+ if audio is not None:
543
+ sample_rate, audio_data = audio
544
+ audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
545
+ else:
546
+ audio_info = "未收到音频"
547
+
548
+ if image is not None:
549
+ image_info = f"图片尺寸: {image.shape}"
550
+ else:
551
+ image_info = "未收到图片"
552
+
553
+ return audio_info, image_info
554
+
555
+ def respond(message, history, system_message, max_tokens, temperature, top_p, audio, image):
556
+ """处理聊天响应,支持流式输出"""
557
+ chat_messages = [{"role": "system", "content": system_message}]
558
+ for val in history:
559
+ if val[0]:
560
+ chat_messages.append({"role": "user", "content": val[0]})
561
+ if val[1]:
562
+ chat_messages.append({"role": "assistant", "content": val[1]})
563
+ chat_messages.append({"role": "user", "content": message})
564
+
565
+ chat_response_stream = ""
566
+ if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
567
+ yield "Filter API 配置不完整,无法提供聊天回复。", []
568
+ return
569
+
570
+ headers = {
571
+ "Authorization": f"Bearer {Filter_API_KEY}",
572
+ "Accept": "application/json"
573
+ }
574
+ payload = {
575
+ "model": Filter_MODEL_NAME,
576
+ "messages": chat_messages,
577
+ "temperature": temperature,
578
+ "top_p": top_p,
579
+ "max_tokens": max_tokens,
580
+ "stream": True
581
+ }
582
+ api_url = f"{Filter_API_BASE_URL}/chat/completions"
583
+
584
+ try:
585
+ response = requests.post(api_url, headers=headers, json=payload, stream=True)
586
+ response.raise_for_status()
587
+
588
+ for chunk in response.iter_content(chunk_size=None):
589
+ if chunk:
590
+ try:
591
+ for line in chunk.decode('utf-8').splitlines():
592
+ if line.startswith('data: '):
593
+ json_data = line[len('data: '):]
594
+ if json_data.strip() == '[DONE]':
595
+ break
596
+ data = json.loads(json_data)
597
+ token = data['choices'][0]['delta'].get('content', '')
598
+ if token:
599
+ chat_response_stream += token
600
+ yield chat_response_stream, []
601
+ except json.JSONDecodeError:
602
+ pass
603
+ except Exception as e:
604
+ yield chat_response_stream + f"\n\n错误: {e}", []
605
+
606
+ except requests.exceptions.RequestException as e:
607
+ yield f"调用 NVIDIA API 失败: {e}", []
608
+ # 多模态上传入口(文本/语音/图片)
609
+ with gr.Blocks() as app:
610
+ gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")
611
+
612
+ with gr.Row():
613
+ with gr.Column(scale=2):
614
+ gr.Markdown("## Chat Interface")
615
+ chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages")
616
+ msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")
617
+
618
+ with gr.Row():
619
+ audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
620
+ image_input = gr.Image(label="上传图片", type="numpy")
621
+
622
+ with gr.Accordion("高级设置", open=False):
623
+ system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
624
+ max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)")
625
+ temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
626
+ top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")
627
+
628
+ with gr.Row():
629
+ submit_btn = gr.Button("发送", variant="primary")
630
+ clear_btn = gr.Button("清除聊天和ToDo")
631
+
632
+ with gr.Column(scale=1):
633
+ gr.Markdown("## Generated ToDo List")
634
+ todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
635
+ datatype=["number", "str", "str"],
636
+ row_count=(0, "dynamic"),
637
+ col_count=(3, "fixed"),
638
+ label="待办事项列表")
639
+
640
+ def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f):
641
+ """处理用户提交的消息,生成聊天回复和待办事项"""
642
+ # 首先处理多模态输入,获取多模态内容
643
+ multimodal_text_content = ""
644
+ xunfei_config = get_hf_xunfei_config()
645
+ xunfei_appid = xunfei_config.get('appid')
646
+ xunfei_apikey = xunfei_config.get('apikey')
647
+ xunfei_apisecret = xunfei_config.get('apisecret')
648
+
649
+ # 添加调试日志
650
+ logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")
651
+ logger.info(f"讯飞配置状态 - appid: {bool(xunfei_appid)}, apikey: {bool(xunfei_apikey)}, apisecret: {bool(xunfei_apisecret)}")
652
+
653
+ # 处理音频输入(独立处理)
654
+ if audio_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
655
+ logger.info("开始处理音频输入...")
656
+ try:
657
+ import tempfile
658
+ import soundfile as sf
659
+ import os
660
+
661
+ audio_sample_rate, audio_data = audio_f
662
+ logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")
663
+
664
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
665
+ sf.write(temp_audio.name, audio_data, audio_sample_rate)
666
+ temp_audio_path = temp_audio.name
667
+ logger.info(f"音频临时文件已保存: {temp_audio_path}")
668
+
669
+ audio_text = audio_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_audio_path)
670
+ logger.info(f"音频识别结果: {audio_text}")
671
+ if audio_text:
672
+ multimodal_text_content += f"音频内容: {audio_text}"
673
+
674
+ os.unlink(temp_audio_path)
675
+ logger.info("音频处理完成")
676
+ except Exception as e:
677
+ logger.error(f"音频处理错误: {str(e)}")
678
+ elif audio_f is not None:
679
+ logger.warning("音频文件存在但讯飞配置不完整,跳过音频处理")
680
+
681
+ # 处理图像输入(独立处理)
682
+ if image_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
683
+ logger.info("开始处理图像输入...")
684
+ try:
685
+ import tempfile
686
+ from PIL import Image
687
+ import os
688
+
689
+ logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")
690
+
691
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
692
+ if len(image_f.shape) == 3: # RGB图像
693
+ pil_image = Image.fromarray(image_f.astype('uint8'), 'RGB')
694
+ else: # 灰度图像
695
+ pil_image = Image.fromarray(image_f.astype('uint8'), 'L')
696
+
697
+ pil_image.save(temp_image.name, 'JPEG')
698
+ temp_image_path = temp_image.name
699
+ logger.info(f"图像临时文件已保存: {temp_image_path}")
700
+
701
+ image_text = image_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_image_path)
702
+ logger.info(f"图像识别结果: {image_text}")
703
+ if image_text:
704
+ if multimodal_text_content: # 如果已有音频内容,添加分隔符
705
+ multimodal_text_content += "\n"
706
+ multimodal_text_content += f"图像内容: {image_text}"
707
+
708
+ os.unlink(temp_image_path)
709
+ logger.info("图像处理完成")
710
+ except Exception as e:
711
+ logger.error(f"图像处理错误: {str(e)}")
712
+ elif image_f is not None:
713
+ logger.warning("图像文件存在但讯飞配置不完整,跳过图像处理")
714
+
715
+ # 确定最终的用户输入内容:如果用户没有输入文本,使用多模态识别的内容
716
+ final_user_content = user_msg_content.strip() if user_msg_content else ""
717
+ if not final_user_content and multimodal_text_content:
718
+ final_user_content = multimodal_text_content
719
+ logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
720
+ elif final_user_content and multimodal_text_content:
721
+ # 用户有文本输入,多模态内容作为补充
722
+ final_user_content = f"{final_user_content}\n{multimodal_text_content}"
723
+ logger.info(f"用户有文本输入,多模态内容作为补充")
724
+
725
+ # 如果最终还是没有任何内容,提供默认提示
726
+ if not final_user_content:
727
+ final_user_content = "[无输入内容]"
728
+ logger.warning("用户没有提供任何输入内容(文本、音频或图像)")
729
+
730
+ logger.info(f"最终用户输入内容: {final_user_content}")
731
+
732
+ # 1. 更新聊天记录 (用户部分) - 使用最终确定的用户内容
733
+ if not ch_history: ch_history = []
734
+ ch_history.append({"role": "user", "content": final_user_content})
735
+ yield ch_history, []
736
+
737
+ # 2. 流式生成机器人回复并更新聊天记录
738
+ formatted_hist_for_respond = []
739
+ temp_user_msg_for_hist = None
740
+ for item_hist in ch_history[:-1]:
741
+ if item_hist["role"] == "user":
742
+ temp_user_msg_for_hist = item_hist["content"]
743
+ elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
744
+ formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
745
+ temp_user_msg_for_hist = None
746
+ elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
747
+ formatted_hist_for_respond.append(("", item_hist["content"]))
748
+
749
+ ch_history.append({"role": "assistant", "content": ""})
750
+
751
+ full_bot_response = ""
752
+ # 使用最终确定的用户内容进行对话
753
+ for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
754
+ full_bot_response = bot_response_token
755
+ ch_history[-1]["content"] = full_bot_response
756
+ yield ch_history, []
757
+
758
+ # 3. 生成 ToDoList - 使用最终确定的用户内容
759
+ text_for_todo = final_user_content
760
+
761
+ # 添加日志:输出用于ToDo生成的内容
762
+ logger.info(f"用于ToDo生成的内容: {text_for_todo}")
763
+ current_todos_list = []
764
+
765
+ filtered_result = filter_message_with_llm(text_for_todo)
766
+
767
+ if isinstance(filtered_result, dict) and "error" in filtered_result:
768
+ current_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
769
+ elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
770
+ current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
771
+ elif isinstance(filtered_result, list):
772
+ category = None
773
+
774
+ if not filtered_result:
775
+ if text_for_todo:
776
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
777
+ current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
778
+ yield ch_history, current_todos_list
779
+ return
780
+
781
+ valid_item = None
782
+ for item in filtered_result:
783
+ if isinstance(item, dict):
784
+ valid_item = item
785
+ if "分类" in item:
786
+ category = item["分类"]
787
+ break
788
+
789
+ if valid_item is None:
790
+ if text_for_todo:
791
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
792
+ current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
793
+ yield ch_history, current_todos_list
794
+ return
795
+
796
+ if category == "其他":
797
+ current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
798
+ else:
799
+ if text_for_todo:
800
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
801
+ current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
802
+ else:
803
+ if text_for_todo:
804
+ msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
805
+ current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
806
+
807
+ yield ch_history, current_todos_list
808
+
809
+ submit_btn.click(
810
+ handle_submit,
811
+ [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
812
+ [chatbot, todolist_df]
813
+ )
814
+ msg.submit(
815
+ handle_submit,
816
+ [msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
817
+ [chatbot, todolist_df]
818
+ )
819
+
820
+ def clear_all():
821
+ """清除所有聊天记录和待办事项"""
822
+ return None, None, ""
823
+ clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)
824
+ # 多模态标签页
825
+ with gr.Tab("Audio/Image Processing (Original)"):
826
+ gr.Markdown("## 处理音频和图片")
827
+ audio_processor = gr.Audio(label="上传音频", type="numpy")
828
+ image_processor = gr.Image(label="上传图片", type="numpy")
829
+ process_btn = gr.Button("处理", variant="primary")
830
+ audio_output = gr.Textbox(label="音频信息")
831
+ image_output = gr.Textbox(label="图片信息")
832
+
833
+ process_btn.click(
834
+ process,
835
+ inputs=[audio_processor, image_processor],
836
+ outputs=[audio_output, image_output]
837
+ )
838
+
839
+ if __name__ == "__main__":
840
+ app.launch(debug=False)
audio_127.0.0.1.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4cca96c289e5acdfd9d8e926bb40674e170374878d57e4d3c3f5aca3039bec8
3
+ size 1830956
image_127.0.0.1.jpg ADDED
requirements.txt CHANGED
@@ -1,4 +1,8 @@
1
- gradio
2
- requests
3
- pathlib
4
- python-dateutil
1
+ gradio
2
+ requests
3
+ pathlib
4
+ python-dateutil
5
+ Pillow
6
+ numpy
7
+ soundfile
8
+ azure-ai-inference
se_app.py ADDED
@@ -0,0 +1,232 @@
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import os
4
+ import numpy as np
5
+ from scipy.io.wavfile import write as write_wav
6
+ from PIL import Image
7
+ from tools import audio_to_str, image_to_str # 导入tools.py中的方法
8
+
9
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
10
+
11
+ # 指定保存文件的相对路径
12
+ SAVE_DIR = 'download' # 相对路径
13
+ os.makedirs(SAVE_DIR, exist_ok=True) # 确保目录存在
14
+
15
+ def get_client_ip(request: gr.Request, debug_mode=False):
16
+ """获取客户端真实IP地址"""
17
+ if request:
18
+ # 从请求头中获取真实IP(考虑代理情况)
19
+ x_forwarded_for = request.headers.get("x-forwarded-for", "")
20
+ if x_forwarded_for:
21
+ client_ip = x_forwarded_for.split(",")[0]
22
+ else:
23
+ client_ip = request.client.host
24
+ if debug_mode:
25
+ print(f"Debug: Client IP detected as {client_ip}")
26
+ return client_ip
27
+ return "unknown"
28
+
29
+ def save_audio(audio, filename):
30
+ """保存音频为.wav文件"""
31
+ sample_rate, audio_data = audio
32
+ write_wav(filename, sample_rate, audio_data)
33
+
34
+ def save_image(image, filename):
35
+ """保存图片为.jpg文件"""
36
+ img = Image.fromarray(image.astype('uint8'))
37
+ img.save(filename)
38
+
39
+ def process(audio, image, text, request: gr.Request):
40
+ """处理语音、图片和文本的示例函数"""
41
+ client_ip = get_client_ip(request, True)
42
+ print(f"Processing request from IP: {client_ip}")
43
+
44
+ audio_info = "未收到音频"
45
+ image_info = "未收到图片"
46
+ text_info = "未收到文本"
47
+ audio_filename = None
48
+ image_filename = None
49
+ audio_text = ""
50
+ image_text = ""
51
+
52
+ if audio is not None:
53
+ sample_rate, audio_data = audio
54
+ audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
55
+ # 保存音频为.wav文件
56
+ audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
57
+ save_audio(audio, audio_filename)
58
+ print(f"Audio saved as {audio_filename}")
59
+ # 调用tools.py中的audio_to_str方法处理音频
60
+ audio_text = audio_to_str("33c1b63d", "40bf7cd82e31ace30a9cfb76309a43a3", "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4", audio_filename)
61
+ if audio_text:
62
+ print(f"Audio text: {audio_text}")
63
+ else:
64
+ print("Audio processing failed")
65
+
66
+ if image is not None:
67
+ image_info = f"图片尺寸: {image.shape}"
68
+ # 保存图片为.jpg文件
69
+ image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
70
+ save_image(image, image_filename)
71
+ print(f"Image saved as {image_filename}")
72
+ # 调用tools.py中的image_to_str方法处理图片
73
+ image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
74
+ if image_text:
75
+ print(f"Image text: {image_text}")
76
+ else:
77
+ print("Image processing failed")
78
+
79
+ if text:
80
+ text_info = f"接收到文本: {text}"
81
+
82
+ return audio_info, image_info, text_info, audio_text, image_text
83
+
84
+ # 创建自定义的聊天界面
85
+ with gr.Blocks() as app:
86
+ gr.Markdown("# ToDoAgent Multi-Modal Interface")
87
+
88
+ # 创建两个标签页
89
+ with gr.Tab("Chat"):
90
+ # 修复Chatbot类型警告
91
+ chatbot = gr.Chatbot(height=500, type="messages")
92
+
93
+ msg = gr.Textbox(label="输入消息", placeholder="输入您的问题...")
94
+
95
+ # 上传区域
96
+ with gr.Row():
97
+ audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
98
+ image_input = gr.Image(label="上传图片", type="numpy")
99
+
100
+ # 设置区域
101
+ with gr.Accordion("高级设置", open=False):
102
+ system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
103
+ max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="最大生成长度")
104
+ temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="温度")
105
+ top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
106
+
107
+ # 提交按钮
108
+ submit_btn = gr.Button("发送", variant="primary")
109
+
110
+ # 清除按钮
111
+ clear = gr.Button("清除聊天")
112
+
113
+ # 事件处理
114
+ def user(user_message, chat_history):
115
+ return "", chat_history + [{"role": "user", "content": user_message}]
116
+ #新增多模态处理--1
117
+ def respond(message, chat_history, system_message, max_tokens, temperature, top_p, audio=None, image=None, text=None, request=None):
118
+ """生成响应的函数"""
119
+ # 处理多模态输入
120
+ multimodal_content = ""
121
+ if audio is not None:
122
+ try:
123
+ audio_filename = os.path.join(SAVE_DIR, "temp_audio.wav")
124
+ save_audio(audio, audio_filename)
125
+ audio_text = audio_to_str("33c1b63d", "40bf7cd82e31ace30a9cfb76309a43a3", "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4", audio_filename)
126
+ if audio_text:
127
+ multimodal_content += f"音频内容: {audio_text}\n"
128
+ except Exception as e:
129
+ print(f"Audio processing error: {e}")
130
+
131
+ if image is not None:
132
+ try:
133
+ image_filename = os.path.join(SAVE_DIR, "temp_image.jpg")
134
+ save_image(image, image_filename)
135
+ image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
136
+ if image_text:
137
+ multimodal_content += f"图片内容: {image_text}\n"
138
+ except Exception as e:
139
+ print(f"Image processing error: {e}")
140
+
141
+ # 组合最终消息
142
+ final_message = message
143
+ if multimodal_content:
144
+ final_message = f"{message}\n\n{multimodal_content}"
145
+
146
+ # 构建消息历史
147
+ messages = [{"role": "system", "content": system_message}]
148
+ for chat in chat_history:
149
+ if isinstance(chat, dict) and "role" in chat and "content" in chat:
150
+ messages.append(chat)
151
+
152
+ messages.append({"role": "user", "content": final_message})
153
+
154
+ # 调用HuggingFace API
155
+ try:
156
+ response = client.chat_completion(
157
+ messages,
158
+ max_tokens=max_tokens,
159
+ stream=True,
160
+ temperature=temperature,
161
+ top_p=top_p,
162
+ )
163
+
164
+ partial_message = ""
165
+ for token in response:
166
+ if token.choices[0].delta.content is not None:
167
+ partial_message += token.choices[0].delta.content
168
+ yield partial_message
169
+ except Exception as e:
170
+ yield f"抱歉,生成响应时出现错误: {str(e)}"
171
+
172
+ def bot(chat_history, system_message, max_tokens, temperature, top_p, audio, image, text):
173
+ # 检查chat_history是否为空
174
+ if not chat_history or len(chat_history) == 0:
175
+ return
176
+
177
+ # 获取最后一条用户消息
178
+ last_message = chat_history[-1]
179
+ if not last_message or not isinstance(last_message, dict) or "content" not in last_message:
180
+ return
181
+
182
+ user_message = last_message["content"]
183
+
184
+ # 生成响应
185
+ bot_response = ""
186
+ for response in respond(
187
+ user_message,
188
+ chat_history[:-1],
189
+ system_message,
190
+ max_tokens,
191
+ temperature,
192
+ top_p,
193
+ audio,
194
+ image,
195
+ text
196
+ ):
197
+ bot_response = response
198
+ # 添加助手回复到聊天历史
199
+ updated_history = chat_history + [{"role": "assistant", "content": bot_response}]
200
+ yield updated_history
201
+
202
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
203
+ bot, [chatbot, system_msg, max_tokens, temperature, top_p, audio_input, image_input, msg], chatbot
204
+ )
205
+
206
+ submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
207
+ bot, [chatbot, system_msg, max_tokens, temperature, top_p, audio_input, image_input, msg], chatbot
208
+ )
209
+
210
+ clear.click(lambda: None, None, chatbot, queue=False)
211
+
212
+ with gr.Tab("Audio/Image Processing"):
213
+ gr.Markdown("## 处理音频和图片")
214
+ audio_processor = gr.Audio(label="上传音频", type="numpy")
215
+ image_processor = gr.Image(label="上传图片", type="numpy")
216
+ text_input = gr.Textbox(label="输入文本")
217
+ process_btn = gr.Button("处理", variant="primary")
218
+ audio_output = gr.Textbox(label="音频信息")
219
+ image_output = gr.Textbox(label="图片信息")
220
+ text_output = gr.Textbox(label="文本信息")
221
+ audio_text_output = gr.Textbox(label="音频转文字结果")
222
+ image_text_output = gr.Textbox(label="图片转文字结果")
223
+
224
+ # 修改后的处理函数调用
225
+ process_btn.click(
226
+ process,
227
+ inputs=[audio_processor, image_processor, text_input],
228
+ outputs=[audio_output, image_output, text_output, audio_text_output, image_text_output]
229
+ )
230
+
231
+ if __name__ == "__main__":
232
+ app.launch()
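For reference, when both an audio clip and an image are attached, respond() above transcribes them and appends the text to the user message before calling client.chat_completion. An illustrative sketch of the resulting message list (contents invented here; the system prompt is the default value from the settings accordion):

messages = [
    {"role": "system", "content": "You are a friendly Chatbot."},
    {"role": "user",
     "content": "帮我记一下\n\n音频内容: 明天下午三点开会\n图片内容: Team offsite Friday 3pm\n"},
]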
temp_audio.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8a873051a6c784789c314ab829772eac2446271337f54d58db48e921e81ab71e
3
+ size 710700
todogen_LLM_config.yaml CHANGED
@@ -38,4 +38,14 @@ HF_CONFIG_PATH:
38
  openai_filter:
39
  base_url_filter: https://aihubmix.com/v1
40
  api_key_filter: sk-BSNyITzJBSSgfFdJ792b66C7789c479cA7Ec1e36FfB343A1
41
- model_filter: gpt-4o-mini
38
  openai_filter:
39
  base_url_filter: https://aihubmix.com/v1
40
  api_key_filter: sk-BSNyITzJBSSgfFdJ792b66C7789c479cA7Ec1e36FfB343A1
41
+ model_filter: gpt-4o-mini
42
+
43
+ xunfei:
44
+ appid: 33c1b63d
45
+ apikey: 40bf7cd82e31ace30a9cfb76309a43a3
46
+ apisecret: OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4
47
+
48
+ azure_speech:
49
+ key: 45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ
50
+ region: eastus2
51
+ endpoint: https://eastus2.stt.speech.microsoft.com
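The new xunfei and azure_speech sections mirror credentials that app.py and tools.py currently pass as inline literals. A minimal sketch of reading them from this file instead, assuming PyYAML is installed and the caller runs next to todogen_LLM_config.yaml (illustrative only, not part of this commit):

import yaml
from tools import audio_to_str

with open("todogen_LLM_config.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

xf = cfg["xunfei"]
# Same call app.py makes with hardcoded literals, now driven by the config file.
transcript = audio_to_str(xf["appid"], xf["apikey"], xf["apisecret"], "temp_audio.wav")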
tools.py ADDED
@@ -0,0 +1,828 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding:utf-8 -*-
3
+ import os
4
+ import datetime
5
+ import re
6
+ import time
7
+ import traceback
8
+ import math
9
+ from urllib.parse import urlparse
10
+ from urllib3 import encode_multipart_formdata
11
+ from wsgiref.handlers import format_date_time
12
+ from time import mktime
13
+ import hashlib
14
+ import base64
15
+ import hmac
16
+ from urllib.parse import urlencode
17
+ import json
18
+ import requests
19
+ import azure.cognitiveservices.speech as speechsdk
20
+
21
+ # Constants
22
+ LFASR_HOST = "http://upload-ost-api.xfyun.cn/file" # 文件上传Host
23
+ API_INIT = "/mpupload/init" # 初始化接口
24
+ API_UPLOAD = "/upload" # 上传接口
25
+ API_CUT = "/mpupload/upload" # 分片上传接口
26
+ API_CUT_COMPLETE = "/mpupload/complete" # 分片完成接口
27
+ API_CUT_CANCEL = "/mpupload/cancel" # 分片取消接口
28
+ FILE_PIECE_SIZE = 5242880 # 文件分片大小5M
29
+ PRO_CREATE_URI = "/v2/ost/pro_create"
30
+ QUERY_URI = "/v2/ost/query"
31
+
32
+
33
+ # File upload helper class
34
+ class FileUploader:
35
+ def __init__(self, app_id, api_key, api_secret, upload_file_path):
36
+ self.app_id = app_id
37
+ self.api_key = api_key
38
+ self.api_secret = api_secret
39
+ self.upload_file_path = upload_file_path
40
+
41
+ def get_request_id(self):
42
+ """生成请求ID"""
43
+ return time.strftime("%Y%m%d%H%M")
44
+
45
+ def hashlib_256(self, data):
46
+ """计算 SHA256 哈希"""
47
+ m = hashlib.sha256(bytes(data.encode(encoding="utf-8"))).digest()
48
+ digest = "SHA-256=" + base64.b64encode(m).decode(encoding="utf-8")
49
+ return digest
50
+
51
+ def assemble_auth_header(self, request_url, file_data_type, method="", body=""):
52
+ """组装鉴权头部"""
53
+ u = urlparse(request_url)
54
+ host = u.hostname
55
+ path = u.path
56
+ now = datetime.datetime.now()
57
+ date = format_date_time(mktime(now.timetuple()))
58
+ digest = "SHA256=" + self.hashlib_256("")
59
+ signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1\ndigest: {}".format(
60
+ host, date, method, path, digest
61
+ )
62
+ signature_sha = hmac.new(
63
+ self.api_secret.encode("utf-8"),
64
+ signature_origin.encode("utf-8"),
65
+ digestmod=hashlib.sha256,
66
+ ).digest()
67
+ signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
68
+ authorization = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
69
+ self.api_key,
70
+ "hmac-sha256",
71
+ "host date request-line digest",
72
+ signature_sha,
73
+ )
74
+ headers = {
75
+ "host": host,
76
+ "date": date,
77
+ "authorization": authorization,
78
+ "digest": digest,
79
+ "content-type": file_data_type,
80
+ }
81
+ return headers
82
+
83
+ def call_api(self, url, file_data, file_data_type):
84
+ """调用POST API接口"""
85
+ headers = self.assemble_auth_header(
86
+ url, file_data_type, method="POST", body=file_data
87
+ )
88
+ try:
89
+ resp = requests.post(url, headers=headers, data=file_data, timeout=8)
90
+ print("上传状态:", resp.status_code, resp.text)
91
+ return resp.json()
92
+ except Exception as e:
93
+ print("上传失败!Exception :%s" % e)
94
+ return None
95
+
96
+ def upload_cut_complete(self, upload_id):
97
+ """分块上传完成"""
98
+ body_dict = {
99
+ "app_id": self.app_id,
100
+ "request_id": self.get_request_id(),
101
+ "upload_id": upload_id,
102
+ }
103
+ file_data_type = "application/json"
104
+ url = LFASR_HOST + API_CUT_COMPLETE
105
+ response = self.call_api(url, json.dumps(body_dict), file_data_type)
106
+ if response and "data" in response and "url" in response["data"]:
107
+ file_url = response["data"]["url"]
108
+ print("任务上传结束")
109
+ return file_url
110
+ else:
111
+ print("分片上传完成失败", response)
112
+ return None
113
+
114
+ def upload_file(self):
115
+ """上传文件,根据文件大小选择分片或普通上传"""
116
+ file_total_size = os.path.getsize(self.upload_file_path)
117
+ if file_total_size < 31457280: # 30MB
118
+ print("-----不使用分块上传-----")
119
+ return self.simple_upload()
120
+ else:
121
+ print("-----使用分块上传-----")
122
+ return self.multipart_upload()
123
+
124
+ def simple_upload(self):
125
+ """简单上传文件"""
126
+ try:
127
+ with open(self.upload_file_path, mode="rb") as f:
128
+ file = {
129
+ "data": (self.upload_file_path, f.read()),
130
+ "app_id": self.app_id,
131
+ "request_id": self.get_request_id(),
132
+ }
133
+ encode_data = encode_multipart_formdata(file)
134
+ file_data = encode_data[0]
135
+ file_data_type = encode_data[1]
136
+ url = LFASR_HOST + API_UPLOAD
137
+ response = self.call_api(url, file_data, file_data_type)
138
+ if response and "data" in response and "url" in response["data"]:
139
+ return response["data"]["url"]
140
+ else:
141
+ print("简单上传失败", response)
142
+ return None
143
+ except FileNotFoundError:
144
+ print("文件未找到:", self.upload_file_path)
145
+ return None
146
+
147
+ def multipart_upload(self):
148
+ """分片上传文件"""
149
+ upload_id = self.prepare_upload()
150
+ if not upload_id:
151
+ return None
152
+
153
+ if not self.do_upload(upload_id):
154
+ return None
155
+
156
+ file_url = self.upload_cut_complete(upload_id)
157
+ print("分片上传地址:", file_url)
158
+ return file_url
159
+
160
+ def prepare_upload(self):
161
+ """预处理,获取upload_id"""
162
+ body_dict = {
163
+ "app_id": self.app_id,
164
+ "request_id": self.get_request_id(),
165
+ }
166
+ url = LFASR_HOST + API_INIT
167
+ file_data_type = "application/json"
168
+ response = self.call_api(url, json.dumps(body_dict), file_data_type)
169
+ if response and "data" in response and "upload_id" in response["data"]:
170
+ return response["data"]["upload_id"]
171
+ else:
172
+ print("预处理失败", response)
173
+ return None
174
+
175
+ def do_upload(self, upload_id):
176
+ """执行分片上传"""
177
+ file_total_size = os.path.getsize(self.upload_file_path)
178
+ chunk_size = FILE_PIECE_SIZE
179
+ chunks = math.ceil(file_total_size / chunk_size)
180
+ request_id = self.get_request_id()
181
+ slice_id = 1
182
+
183
+ print(
184
+ "文件:",
185
+ self.upload_file_path,
186
+ " 文件大小:",
187
+ file_total_size,
188
+ " 分块大小:",
189
+ chunk_size,
190
+ " 分块数:",
191
+ chunks,
192
+ )
193
+
194
+ with open(self.upload_file_path, mode="rb") as content:
195
+ while slice_id <= chunks:
196
+ current_size = min(
197
+ chunk_size, file_total_size - (slice_id - 1) * chunk_size
198
+ )
199
+
200
+ file = {
201
+ "data": (self.upload_file_path, content.read(current_size)),
202
+ "app_id": self.app_id,
203
+ "request_id": request_id,
204
+ "upload_id": upload_id,
205
+ "slice_id": slice_id,
206
+ }
207
+
208
+ encode_data = encode_multipart_formdata(file)
209
+ file_data = encode_data[0]
210
+ file_data_type = encode_data[1]
211
+ url = LFASR_HOST + API_CUT
212
+
213
+ resp = self.call_api(url, file_data, file_data_type)
214
+ count = 0
215
+ while not resp and (count < 3):
216
+ print("上传重试")
217
+ resp = self.call_api(url, file_data, file_data_type)
218
+ count = count + 1
219
+ time.sleep(1)
220
+ if not resp:
221
+ print("分片上传失败")
222
+ return False
223
+ slice_id += 1
224
+
225
+ return True
226
+
227
+
228
+ class ResultExtractor:
229
+ def __init__(self, appid, apikey, apisecret):
230
+ # POST 请求相关参数
231
+ self.Host = "ost-api.xfyun.cn"
232
+ self.RequestUriCreate = PRO_CREATE_URI
233
+ self.RequestUriQuery = QUERY_URI
234
+ # 设置 URL
235
+ if re.match(r"^\d", self.Host):
236
+ self.urlCreate = "http://" + self.Host + self.RequestUriCreate
237
+ self.urlQuery = "http://" + self.Host + self.RequestUriQuery
238
+ else:
239
+ self.urlCreate = "https://" + self.Host + self.RequestUriCreate
240
+ self.urlQuery = "https://" + self.Host + self.RequestUriQuery
241
+ self.HttpMethod = "POST"
242
+ self.APPID = appid
243
+ self.Algorithm = "hmac-sha256"
244
+ self.HttpProto = "HTTP/1.1"
245
+ self.UserName = apikey
246
+ self.Secret = apisecret
247
+
248
+ # 设置当前时间
249
+ cur_time_utc = datetime.datetime.now(datetime.timezone.utc)
250
+ self.Date = self.httpdate(cur_time_utc)
251
+
252
+ # 设置测试音频文件参数
253
+ self.BusinessArgsCreate = {
254
+ "language": "zh_cn",
255
+ "accent": "mandarin",
256
+ "domain": "pro_ost_ed",
257
+ }
258
+
259
+ def img_read(self, path):
260
+ with open(path, "rb") as fo:
261
+ return fo.read()
262
+
263
+ def hashlib_256(self, res):
264
+ m = hashlib.sha256(bytes(res.encode(encoding="utf-8"))).digest()
265
+ result = "SHA-256=" + base64.b64encode(m).decode(encoding="utf-8")
266
+ return result
267
+
268
+ def httpdate(self, dt):
269
+ weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"][dt.weekday()]
270
+ month = [
271
+ "Jan",
272
+ "Feb",
273
+ "Mar",
274
+ "Apr",
275
+ "May",
276
+ "Jun",
277
+ "Jul",
278
+ "Aug",
279
+ "Sep",
280
+ "Oct",
281
+ "Nov",
282
+ "Dec",
283
+ ][dt.month - 1]
284
+ return "%s, %02d %s %04d %02d:%02d:%02d GMT" % (
285
+ weekday,
286
+ dt.day,
287
+ month,
288
+ dt.year,
289
+ dt.hour,
290
+ dt.minute,
291
+ dt.second,
292
+ )
293
+
294
+ def generateSignature(self, digest, uri):
295
+ signature_str = "host: " + self.Host + "\n"
296
+ signature_str += "date: " + self.Date + "\n"
297
+ signature_str += self.HttpMethod + " " + uri + " " + self.HttpProto + "\n"
298
+ signature_str += "digest: " + digest
299
+ signature = hmac.new(
300
+ bytes(self.Secret.encode("utf-8")),
301
+ bytes(signature_str.encode("utf-8")),
302
+ digestmod=hashlib.sha256,
303
+ ).digest()
304
+ result = base64.b64encode(signature)
305
+ return result.decode(encoding="utf-8")
306
+
307
+ def init_header(self, data, uri):
308
+ digest = self.hashlib_256(data)
309
+ sign = self.generateSignature(digest, uri)
310
+ auth_header = (
311
+ 'api_key="%s",algorithm="%s", '
312
+ 'headers="host date request-line digest", '
313
+ 'signature="%s"' % (self.UserName, self.Algorithm, sign)
314
+ )
315
+ headers = {
316
+ "Content-Type": "application/json",
317
+ "Accept": "application/json",
318
+ "Method": "POST",
319
+ "Host": self.Host,
320
+ "Date": self.Date,
321
+ "Digest": digest,
322
+ "Authorization": auth_header,
323
+ }
324
+ return headers
325
+
326
+ def get_create_body(self, fileurl):
327
+ post_data = {
328
+ "common": {"app_id": self.APPID},
329
+ "business": self.BusinessArgsCreate,
330
+ "data": {"audio_src": "http", "audio_url": fileurl, "encoding": "raw"},
331
+ }
332
+ body = json.dumps(post_data)
333
+ return body
334
+
335
+ def get_query_body(self, task_id):
336
+ post_data = {
337
+ "common": {"app_id": self.APPID},
338
+ "business": {
339
+ "task_id": task_id,
340
+ },
341
+ }
342
+ body = json.dumps(post_data)
343
+ return body
344
+
345
+ def call(self, url, body, headers):
346
+ try:
347
+ response = requests.post(url, data=body, headers=headers, timeout=8)
348
+ status_code = response.status_code
349
+ if status_code != 200:
350
+ info = response.content
351
+ return info
352
+ else:
353
+ try:
354
+ return json.loads(response.text)
355
+ except json.JSONDecodeError:
356
+ return response.text
357
+ except Exception as e:
358
+ print("Exception :%s" % e)
359
+ return None
360
+
361
+ def task_create(self, fileurl):
362
+ body = self.get_create_body(fileurl)
363
+ headers_create = self.init_header(body, self.RequestUriCreate)
364
+ return self.call(self.urlCreate, body, headers_create)
365
+
366
+ def task_query(self, task_id):
367
+ query_body = self.get_query_body(task_id)
368
+ headers_query = self.init_header(query_body, self.RequestUriQuery)
369
+ return self.call(self.urlQuery, query_body, headers_query)
370
+
371
+ def extract_text(self, result):
372
+ """
373
+ Extract the recognized text from an API response.
374
+ Supports several result formats, with defensive error handling.
375
+ """
376
+ # 调试输出:打印原始结果类型
377
+ print(f"\n[DEBUG] extract_text 输入类型: {type(result)}")
378
+
379
+ # 如果是字符串,尝试解析为JSON
380
+ if isinstance(result, str):
381
+ print(f"[DEBUG] 字符串内容 (前200字符): {result[:200]}")
382
+ try:
383
+ result = json.loads(result)
384
+ print("[DEBUG] 成功解析字符串为JSON对象")
385
+ except json.JSONDecodeError:
386
+ print("[DEBUG] 无法解析为JSON,返回原始字符串")
387
+ return result
388
+
389
+ # 处理字典类型的结果
390
+ if isinstance(result, dict):
391
+ print("[DEBUG] 处理字典类型结果")
392
+
393
+ # 1. 检查错误信息
394
+ if "code" in result and result["code"] != 0:
395
+ error_msg = result.get("message", "未知错误")
396
+ print(
397
+ f"[ERROR] API返回错误: code={result['code']}, message={error_msg}"
398
+ )
399
+ return f"错误: {error_msg}"
400
+
401
+ # 2. 检查直接包含文本结果的情况
402
+ if "result" in result and isinstance(result["result"], str):
403
+ print("[DEBUG] 找到直接结果字段")
404
+ return result["result"]
405
+
406
+ # 3. 检查lattice结构(详细结果)
407
+ if "lattice" in result and isinstance(result["lattice"], list):
408
+ print("[DEBUG] 解析lattice结构")
409
+ text_parts = []
410
+ for lattice in result["lattice"]:
411
+ if not isinstance(lattice, dict):
412
+ continue
413
+
414
+ # 获取json_1best内容
415
+ json_1best = lattice.get("json_1best", {})
416
+ if not json_1best or not isinstance(json_1best, dict):
417
+ continue
418
+
419
+ # 处理st字段 - 修正:st可能是字典或列表
420
+ st_content = json_1best.get("st")
421
+ st_list = []
422
+ if isinstance(st_content, dict):
423
+ st_list = [st_content] # 转为列表统一处理
424
+ elif isinstance(st_content, list):
425
+ st_list = st_content
426
+
427
+ for st in st_list:
428
+ if isinstance(st, str):
429
+ # 直接是字符串结果
430
+ text_parts.append(st)
431
+ elif isinstance(st, dict):
432
+ # 处理字典结构的st
433
+ rt = st.get("rt", [])
434
+ if not isinstance(rt, list):
435
+ continue
436
+
437
+ for item in rt:
438
+ if isinstance(item, dict):
439
+ ws_list = item.get("ws", [])
440
+ if isinstance(ws_list, list):
441
+ for ws in ws_list:
442
+ if isinstance(ws, dict):
443
+ cw_list = ws.get("cw", [])
444
+ if isinstance(cw_list, list):
445
+ for cw in cw_list:
446
+ if isinstance(cw, dict):
447
+ w = cw.get("w", "")
448
+ if w:
449
+ text_parts.append(w)
450
+ return "".join(text_parts)
451
+
452
+ # 4. 检查简化结构(直接包含st)
453
+ if "st" in result and isinstance(result["st"], list):
454
+ print("[DEBUG] 解析st结构")
455
+ text_parts = []
456
+ for st in result["st"]:
457
+ if isinstance(st, str):
458
+ text_parts.append(st)
459
+ elif isinstance(st, dict):
460
+ rt = st.get("rt", [])
461
+ if isinstance(rt, list):
462
+ for item in rt:
463
+ if isinstance(item, dict):
464
+ ws_list = item.get("ws", [])
465
+ if isinstance(ws_list, list):
466
+ for ws in ws_list:
467
+ if isinstance(ws, dict):
468
+ cw_list = ws.get("cw", [])
469
+ if isinstance(cw_list, list):
470
+ for cw in cw_list:
471
+ if isinstance(cw, dict):
472
+ w = cw.get("w", "")
473
+ if w:
474
+ text_parts.append(w)
475
+ return "".join(text_parts)
476
+
477
+ # 5. 其他未知结构
478
+ print("[WARNING] 无法识别的结果结构")
479
+ return json.dumps(result, indent=2, ensure_ascii=False)
480
+
481
+ # 6. 非字典类型结果
482
+ print(f"[WARNING] 非字典类型结果: {type(result)}")
483
+ return str(result)
484
+
485
+
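As an illustration of what the lattice branch of extract_text above consumes, the following minimal response shape (invented for this sketch, not an actual xfyun payload) would be joined into the string 明天下午开会:

example = {
    "lattice": [
        {"json_1best": {"st": {"rt": [{"ws": [
            {"cw": [{"w": "明天"}]},
            {"cw": [{"w": "下午"}]},
            {"cw": [{"w": "开会"}]},
        ]}]}}}
    ]
}
# Dummy credentials suffice: __init__ only stores them and builds request URLs.
print(ResultExtractor("appid", "apikey", "apisecret").extract_text(example))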
486
+ def audio_to_str(appid, apikey, apisecret, file_path):
487
+ """
488
+ Call the iFLYTEK (xfyun) open platform API and return the transcription of an audio file.
489
+
490
+ Parameters:
491
+ appid (str): iFLYTEK open platform appid.
492
+ apikey (str): iFLYTEK open platform apikey.
493
+ apisecret (str): iFLYTEK open platform apisecret.
494
+ file_path (str): Path to the audio file.
495
+
496
+ Returns:
497
+ str: The transcription text, or None if an error occurred.
498
+ """
499
+ # 检查文件是否存在
500
+ if not os.path.exists(file_path):
501
+ print(f"错误:文件 {file_path} 不存在")
502
+ return None
503
+
504
+ try:
505
+ # 1. 文件上传
506
+ file_uploader = FileUploader(
507
+ app_id=appid,
508
+ api_key=apikey,
509
+ api_secret=apisecret,
510
+ upload_file_path=file_path,
511
+ )
512
+ fileurl = file_uploader.upload_file()
513
+ if not fileurl:
514
+ print("文件上传失败")
515
+ return None
516
+ print("文件上传成功,fileurl:", fileurl)
517
+
518
+ # 2. 创建任务并查询结果
519
+ result_extractor = ResultExtractor(appid, apikey, apisecret)
520
+ print("\n------ 创建任务 -------")
521
+ create_response = result_extractor.task_create(fileurl)
522
+
523
+ # 调试输出创建响应
524
+ print(
525
+ f"[DEBUG] 创建任务响应: {json.dumps(create_response, indent=2, ensure_ascii=False)}"
526
+ )
527
+
528
+ if not isinstance(create_response, dict) or "data" not in create_response:
529
+ print("创建任务失败:", create_response)
530
+ return None
531
+
532
+ task_id = create_response["data"]["task_id"]
533
+ print(f"任务ID: {task_id}")
534
+
535
+ # 查询任务
536
+ print("\n------ 查询任务 -------")
537
+ print("任务转写中······")
538
+ max_attempts = 30
539
+ attempt = 0
540
+
541
+ while attempt < max_attempts:
542
+ result = result_extractor.task_query(task_id)
543
+
544
+ # 调试输出查询响应
545
+ print(f"\n[QUERY {attempt + 1}] 响应类型: {type(result)}")
546
+ if isinstance(result, dict):
547
+ print(
548
+ f"[QUERY {attempt + 1}] 响应内容: {json.dumps(result, indent=2, ensure_ascii=False)}"
549
+ )
550
+ else:
551
+ print(
552
+ f"[QUERY {attempt + 1}] 响应内容 (前200字符): {str(result)[:200]}"
553
+ )
554
+
555
+ # 检查响应是否有效
556
+ if not isinstance(result, dict):
557
+ print(f"无效响应类型: {type(result)}")
558
+ return None
559
+
560
+ # 检查API错误码
561
+ if "code" in result and result["code"] != 0:
562
+ error_msg = result.get("message", "未知错误")
563
+ print(f"API错误: code={result['code']}, message={error_msg}")
564
+ return None
565
+
566
+ # 获取任务状态
567
+ task_data = result.get("data", {})
568
+ task_status = task_data.get("task_status")
569
+
570
+ if not task_status:
571
+ print("响应中缺少任务状态字段")
572
+ print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
573
+ return None
574
+
575
+ # 处理不同状态
576
+ if task_status in ["3", "4"]: # 任务已完成或回调完成
577
+ print("转写完成···")
578
+
579
+ # 提取结果
580
+ result_content = task_data.get("result")
581
+ if result_content is not None:
582
+ try:
583
+ result_text = result_extractor.extract_text(result_content)
584
+ print("\n转写结果:\n", result_text)
585
+ return result_text
586
+ except Exception as e:
587
+ print(f"\n提取文本时出错: {str(e)}")
588
+ print(f"错误详情:\n{traceback.format_exc()}")
589
+ print(
590
+ "原始结果内容:",
591
+ json.dumps(result_content, indent=2, ensure_ascii=False),
592
+ )
593
+ return None
594
+ else:
595
+ print("\n响应中缺少结果字段")
596
+ print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
597
+ return None
598
+
599
+ elif task_status in ["1", "2"]: # 任务待处理或处理中
600
+ print(
601
+ f"任务状态:{task_status},等待中... (尝试 {attempt + 1}/{max_attempts})"
602
+ )
603
+ time.sleep(5)
604
+ attempt += 1
605
+ else:
606
+ print(f"未知任务状态:{task_status}")
607
+ print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
608
+ return None
609
+ else:
610
+ print(f"超过最大查询次数({max_attempts}),任务可能仍在处理中")
611
+ return None
612
+
613
+ except Exception as e:
614
+ print(f"发生异常: {str(e)}")
615
+ print(f"错误详情:\n{traceback.format_exc()}")
616
+ return None
617
+
618
+
619
+ """
620
+ 1. General OCR: the base64-encoded image must not exceed 10 MB.
621
+ 2. Obtain appid, apiSecret and apiKey from the iFLYTEK open platform console and fill them into this demo.
622
+ 3. Supports Chinese and English, both handwritten and printed text.
623
+ 4. Improved accuracy on tilted text, and recognition of some rare characters.
624
+ """
625
+
626
+ # Image recognition API endpoint
627
+ URL = "https://api.xf-yun.com/v1/private/sf8e6aca1"
628
+
629
+
630
+ class AssembleHeaderException(Exception):
631
+ def __init__(self, msg):
632
+ self.message = msg
633
+
634
+
635
+ class Url:
636
+ def __init__(self, host, path, schema):
637
+ self.host = host
638
+ self.path = path
639
+ self.schema = schema
640
+ pass
641
+
642
+
643
+ # calculate sha256 and encode to base64
644
+ def sha256base64(data):
645
+ sha256 = hashlib.sha256()
646
+ sha256.update(data)
647
+ digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8")
648
+ return digest
649
+
650
+
651
+ def parse_url(requset_url):
652
+ stidx = requset_url.index("://")
653
+ host = requset_url[stidx + 3 :]
654
+ schema = requset_url[: stidx + 3]
655
+ edidx = host.index("/")
656
+ if edidx <= 0:
657
+ raise AssembleHeaderException("invalid request url:" + requset_url)
658
+ path = host[edidx:]
659
+ host = host[:edidx]
660
+ u = Url(host, path, schema)
661
+ return u
662
+
663
+
664
+ # build websocket auth request url
665
+ def assemble_ws_auth_url(requset_url, method="POST", api_key="", api_secret=""):
666
+ u = parse_url(requset_url)
667
+ host = u.host
668
+ path = u.path
669
+ now = datetime.datetime.now()
670
+ date = format_date_time(mktime(now.timetuple()))
671
+ # print(date) # 可选:打印Date值
672
+
673
+ signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
674
+ host, date, method, path
675
+ )
676
+ # print(signature_origin) # 可选:打印签名原文
677
+ signature_sha = hmac.new(
678
+ api_secret.encode("utf-8"),
679
+ signature_origin.encode("utf-8"),
680
+ digestmod=hashlib.sha256,
681
+ ).digest()
682
+ signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
683
+ authorization_origin = (
684
+ 'api_key="%s", algorithm="%s", headers="%s", signature="%s"'
685
+ % (api_key, "hmac-sha256", "host date request-line", signature_sha)
686
+ )
687
+ authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
688
+ encoding="utf-8"
689
+ )
690
+ # print(authorization_origin) # 可选:打印鉴权原文
691
+ values = {"host": host, "date": date, "authorization": authorization}
692
+
693
+ return requset_url + "?" + urlencode(values)
694
+
695
+
696
+ def image_to_str(endpoint=None, key=None, unused_param=None, file_path=None):
697
+ """
698
+ Recognize text in an image with the Azure Computer Vision Read API.
699
+
700
+ Parameters:
701
+ endpoint (str): Azure Computer Vision endpoint URL.
702
+ key (str): Azure Computer Vision API key.
703
+ unused_param (str): Unused; kept only for call-signature compatibility.
704
+ file_path (str): Path to the image file.
705
+
706
+ Returns:
707
+ str: The text recognized in the image, or None if an error occurred.
708
+ """
709
+
710
+ # 默认配置
711
+ if endpoint is None:
712
+ endpoint = "https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/"
713
+ if key is None:
714
+ key = "45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ"
715
+
716
+ try:
717
+ # 读取图片文件
718
+ with open(file_path, "rb") as f:
719
+ image_data = f.read()
720
+
721
+ # 构造请求URL
722
+ analyze_url = endpoint.rstrip('/') + "/vision/v3.2/read/analyze"
723
+
724
+ # 设置请求头
725
+ headers = {
726
+ 'Ocp-Apim-Subscription-Key': key,
727
+ 'Content-Type': 'application/octet-stream'
728
+ }
729
+
730
+ # 发送POST请求开始分析
731
+ response = requests.post(analyze_url, headers=headers, data=image_data)
732
+
733
+ if response.status_code != 202:
734
+ print(f"分析请求失败: {response.status_code}, {response.text}")
735
+ return None
736
+
737
+ # 获取操作位置
738
+ operation_url = response.headers["Operation-Location"]
739
+
740
+ # 轮询结果
741
+ import time
742
+ while True:
743
+ result_response = requests.get(operation_url, headers={'Ocp-Apim-Subscription-Key': key})
744
+ result = result_response.json()
745
+
746
+ if result["status"] == "succeeded":
747
+ # 提取文字
748
+ text_results = []
749
+ if "analyzeResult" in result and "readResults" in result["analyzeResult"]:
750
+ for read_result in result["analyzeResult"]["readResults"]:
751
+ for line in read_result["lines"]:
752
+ text_results.append(line["text"])
753
+
754
+ return " ".join(text_results) if text_results else ""
755
+
756
+ elif result["status"] == "failed":
757
+ print(f"文字识别失败: {result}")
758
+ return None
759
+
760
+ # 等待1秒后重试
761
+ time.sleep(1)
762
+
763
+ except Exception as e:
764
+ print(f"发生异常: {e}")
765
+ return None
766
+
767
+
768
+ if __name__ == "__main__":
769
+ # 输入讯飞开放平台的 appid,secret、key 和文件路径
770
+ appid = "33c1b63d"
771
+ apikey = "40bf7cd82e31ace30a9cfb76309a43a3"
772
+ apisecret = "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4"
773
+ audio_path = r"audio_sample_little.wav" # 确保文件路径正确
774
+ image_path = r"1.png" # 确保文件路径正确
775
+
776
+ # 音频转文字
777
+ audio_text = audio_to_str(appid, apikey, apisecret, audio_path)
778
+ # 图片转文字
779
+ image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_path)
780
+
781
+ print("-"* 20)
782
+
783
+ print("\n音频转文字结果:", audio_text)
784
+ print("\n图片转文字结果:", image_text)
785
+
786
+
787
+ def azure_speech_to_text(speech_key, speech_region, audio_file_path):
788
+ """
789
+ Convert an audio file to text with the Azure Speech service.
790
+
791
+ Parameters:
792
+ speech_key (str): Azure Speech API key.
793
+ speech_region (str): Azure Speech service region.
794
+ audio_file_path (str): Path to the audio file.
795
+
796
+ Returns:
797
+ str: The recognized text, or None if an error occurred.
798
+ """
799
+ try:
800
+ # 设置语音配置
801
+ speech_config = speechsdk.SpeechConfig(subscription=speech_key, region=speech_region)
802
+ speech_config.speech_recognition_language = "zh-CN" # 设置为中文
803
+
804
+ # 设置音频配置
805
+ audio_config = speechsdk.audio.AudioConfig(filename=audio_file_path)
806
+
807
+ # 创建语音识别器
808
+ speech_recognizer = speechsdk.SpeechRecognizer(speech_config=speech_config, audio_config=audio_config)
809
+
810
+ # 执行语音识别
811
+ result = speech_recognizer.recognize_once()
812
+
813
+ # 检查识别结果
814
+ if result.reason == speechsdk.ResultReason.RecognizedSpeech:
815
+ print(f"Azure Speech识别成功: {result.text}")
816
+ return result.text
817
+ elif result.reason == speechsdk.ResultReason.NoMatch:
818
+ print("Azure Speech未识别到语音")
819
+ return None
820
+ elif result.reason == speechsdk.ResultReason.Canceled:
821
+ cancellation_details = result.cancellation_details
822
+ print(f"Azure Speech识别被取消: {cancellation_details.reason}")
823
+ if cancellation_details.reason == speechsdk.CancellationReason.Error:
824
+ print(f"错误详情: {cancellation_details.error_details}")
825
+ return None
826
+ except Exception as e:
827
+ print(f"Azure Speech识别出错: {str(e)}")
828
+ return None
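azure_speech_to_text is defined after the __main__ demo above and is not exercised by it. A hedged usage sketch that wires it to the azure_speech section added to todogen_LLM_config.yaml in this commit (paths assume the config and the temp_audio.wav tracked in this commit sit in the working directory):

import yaml

with open("todogen_LLM_config.yaml", encoding="utf-8") as f:
    az = yaml.safe_load(f)["azure_speech"]

# Single-shot recognition of a local WAV file; the language is fixed to zh-CN above.
text = azure_speech_to_text(az["key"], az["region"], "temp_audio.wav")
print("Azure speech-to-text:", text)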