mvp
- .gitattributes +2 -0
- 1.png +0 -0
- README.md +15 -15
- __init__.py +0 -0
- app.py +1330 -722
- app_pro.py +840 -0
- audio_127.0.0.1.wav +3 -0
- image_127.0.0.1.jpg +0 -0
- requirements.txt +8 -4
- se_app.py +232 -0
- temp_audio.wav +3 -0
- todogen_LLM_config.yaml +11 -1
- tools.py +828 -0
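
The new app.py below loads all of its credentials and paths from todogen_LLM_config.yaml (changed +11 -1 in this commit but not shown here). As a rough, hypothetical sketch only: the key names below are inferred from the get_hf_openai_config / get_hf_openai_filter_config / get_hf_azure_speech_config / get_hf_paths_config helpers in app.py, and every value is a placeholder rather than the real configuration.

import yaml

# Hypothetical config shape -- key names come from app.py's loaders, values are placeholders.
example_config = yaml.safe_load("""
openai:
  base_url: https://example.invalid/v1
  api_key: YOUR_NVIDIA_API_KEY
  model: your/chat-model
openai_filter:
  base_url_filter: https://example.invalid/v1
  api_key_filter: YOUR_FILTER_API_KEY
  model_filter: your/filter-model
azure_speech:
  key: YOUR_AZURE_SPEECH_KEY
  region: eastus
paths:
  prompt_template: prompt_template.txt
  true_positive_examples: TruePositive_few_shot.txt
  false_positive_examples: FalsePositive_few_shot.txt
""")

# These are the lookups app.py performs at startup.
print(example_config["openai"]["model"])
print(example_config["openai_filter"]["model_filter"])
print(example_config["paths"]["prompt_template"])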
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+audio_127.0.0.1.wav filter=lfs diff=lfs merge=lfs -text
+temp_audio.wav filter=lfs diff=lfs merge=lfs -text
1.png
ADDED
(binary image file)
README.md
CHANGED
@@ -1,16 +1,16 @@
----
-title: ToDoAgent
-emoji: 💬
-colorFrom: yellow
-colorTo: purple
-sdk: gradio
-sdk_version: 5.32.0
-app_file: app.py
-pinned: false
-license: bsd
-short_description: AI Agent filters, creates to-do list and reminds smartly
-tags: ['agent-demo-track']
-demo: https://youtu.be/S-wh3Psx15M?si=Wiq7EzmE3dmBvLKQ
----
-
+---
+title: ToDoAgent
+emoji: 💬
+colorFrom: yellow
+colorTo: purple
+sdk: gradio
+sdk_version: 5.32.0
+app_file: app.py
+pinned: false
+license: bsd
+short_description: AI Agent filters, creates to-do list and reminds smartly
+tags: ['agent-demo-track']
+demo: https://youtu.be/S-wh3Psx15M?si=Wiq7EzmE3dmBvLKQ
+---
+
 An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
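
The line above is the stock description of the Gradio chatbot template this Space started from. For orientation, a minimal sketch of that pattern, assuming huggingface_hub's InferenceClient and a placeholder model id; the Space's actual, much larger app.py follows below.

import gradio as gr
from huggingface_hub import InferenceClient

# Placeholder model id -- substitute whichever chat model the Space is meant to use.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

def respond(message, history):
    # history arrives as (user, assistant) pairs; rebuild the chat transcript
    messages = [{"role": "system", "content": "You are a friendly Chatbot."}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    # Stream tokens back to the UI as they arrive
    reply = ""
    for chunk in client.chat_completion(messages, max_tokens=512, stream=True):
        reply += chunk.choices[0].delta.content or ""
        yield reply

gr.ChatInterface(respond).launch()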
__init__.py
ADDED
(empty file)
app.py
CHANGED
@@ -1,722 +1,1330 @@
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
import yaml
|
5 |
+
import re
|
6 |
+
import logging
|
7 |
+
import io
|
8 |
+
import sys
|
9 |
+
import os
|
10 |
+
import re
|
11 |
+
from datetime import datetime, timezone, timedelta
|
12 |
+
import requests
|
13 |
+
|
14 |
+
from tools import FileUploader, ResultExtractor, audio_to_str, image_to_str, azure_speech_to_text #gege的多模态
|
15 |
+
import numpy as np
|
16 |
+
from scipy.io.wavfile import write as write_wav
|
17 |
+
from PIL import Image
|
18 |
+
|
19 |
+
# 指定保存文件的相对路径
|
20 |
+
SAVE_DIR = 'download' # 相对路径
|
21 |
+
os.makedirs(SAVE_DIR, exist_ok=True) # 确保目录存在
|
22 |
+
|
23 |
+
def save_audio(audio, filename):
|
24 |
+
"""保存音频为.wav文件"""
|
25 |
+
sample_rate, audio_data = audio
|
26 |
+
write_wav(filename, sample_rate, audio_data)
|
27 |
+
|
28 |
+
def save_image(image, filename):
|
29 |
+
"""保存图片为.jpg文件"""
|
30 |
+
img = Image.fromarray(image.astype('uint8'))
|
31 |
+
img.save(filename)
|
32 |
+
|
33 |
+
# --- IP获取功能 (从 se_app.py 迁移) ---
|
34 |
+
def get_client_ip(request: gr.Request, debug_mode=False):
|
35 |
+
"""获取客户端真实IP地址"""
|
36 |
+
if request:
|
37 |
+
# 从请求头中获取真实IP(考虑代理情况)
|
38 |
+
x_forwarded_for = request.headers.get("x-forwarded-for", "")
|
39 |
+
if x_forwarded_for:
|
40 |
+
client_ip = x_forwarded_for.split(",")[0]
|
41 |
+
else:
|
42 |
+
client_ip = request.client.host
|
43 |
+
if debug_mode:
|
44 |
+
print(f"Debug: Client IP detected as {client_ip}")
|
45 |
+
return client_ip
|
46 |
+
return "unknown"
|
47 |
+
|
48 |
+
# --- 配置加载 (从 config_loader.py 迁移并简化) ---
|
49 |
+
CONFIG = None
|
50 |
+
HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"
|
51 |
+
|
52 |
+
def load_hf_config():
|
53 |
+
global CONFIG
|
54 |
+
if CONFIG is None:
|
55 |
+
try:
|
56 |
+
with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
|
57 |
+
CONFIG = yaml.safe_load(f)
|
58 |
+
print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
|
59 |
+
except FileNotFoundError:
|
60 |
+
print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
|
61 |
+
CONFIG = {} # 提供一个空配置以避免后续错误
|
62 |
+
except Exception as e:
|
63 |
+
print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
|
64 |
+
CONFIG = {}
|
65 |
+
return CONFIG
|
66 |
+
|
67 |
+
def get_hf_openai_config():
|
68 |
+
config = load_hf_config()
|
69 |
+
return config.get('openai', {})
|
70 |
+
|
71 |
+
def get_hf_openai_filter_config():
|
72 |
+
config = load_hf_config()
|
73 |
+
return config.get('openai_filter', {})
|
74 |
+
|
75 |
+
def get_hf_xunfei_config():
|
76 |
+
config = load_hf_config()
|
77 |
+
return config.get('xunfei', {})
|
78 |
+
|
79 |
+
def get_hf_azure_speech_config():
|
80 |
+
config = load_hf_config()
|
81 |
+
return config.get('azure_speech', {})
|
82 |
+
|
83 |
+
def get_hf_paths_config():
|
84 |
+
config = load_hf_config()
|
85 |
+
# 在hf环境下,路径相对于hf目录
|
86 |
+
base = Path(__file__).resolve().parent
|
87 |
+
paths_cfg = config.get('paths', {})
|
88 |
+
return {
|
89 |
+
'base_dir': base,
|
90 |
+
'prompt_template': base / paths_cfg.get('prompt_template', 'prompt_template.txt'),
|
91 |
+
'true_positive_examples': base / paths_cfg.get('true_positive_examples', 'TruePositive_few_shot.txt'),
|
92 |
+
'false_positive_examples': base / paths_cfg.get('false_positive_examples', 'FalsePositive_few_shot.txt'),
|
93 |
+
# data_dir 和 logging_dir 在 app.py 中可能用途不大,除非需要保存 LLM 输出
|
94 |
+
}
|
95 |
+
|
96 |
+
# --- LLM Client 初始化 (使用 NVIDIA API) ---
|
97 |
+
# 从配置加载 NVIDIA API 的 base_url, api_key 和 model
|
98 |
+
llm_config = get_hf_openai_config()
|
99 |
+
NVIDIA_API_BASE_URL = llm_config.get('base_url')
|
100 |
+
NVIDIA_API_KEY = llm_config.get('api_key')
|
101 |
+
NVIDIA_MODEL_NAME = llm_config.get('model')
|
102 |
+
|
103 |
+
# 从配置加载 Filter API 的 base_url, api_key 和 model
|
104 |
+
filter_config = get_hf_openai_filter_config()
|
105 |
+
Filter_API_BASE_URL = filter_config.get('base_url_filter')
|
106 |
+
Filter_API_KEY = filter_config.get('api_key_filter')
|
107 |
+
Filter_MODEL_NAME = filter_config.get('model_filter')
|
108 |
+
|
109 |
+
|
110 |
+
if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
|
111 |
+
print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
|
112 |
+
# 提供默认值或退出,以便程序可以继续运行,但LLM调用会失败
|
113 |
+
NVIDIA_API_BASE_URL = ""
|
114 |
+
NVIDIA_API_KEY = ""
|
115 |
+
NVIDIA_MODEL_NAME = ""
|
116 |
+
|
117 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
118 |
+
print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
|
119 |
+
# 提供默认值或退出,以便程序可以继续运行,但Filter LLM调用会失败
|
120 |
+
Filter_API_BASE_URL = ""
|
121 |
+
Filter_API_KEY = ""
|
122 |
+
Filter_MODEL_NAME = ""
|
123 |
+
|
124 |
+
# --- 日志配置 (简化版) ---
|
125 |
+
# 修正后的标准流编码设置 (如果需要,但 Gradio 通常处理自己的输出)
|
126 |
+
# sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8')
|
127 |
+
# sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True)
|
128 |
+
# sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', write_through=True)
|
129 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
130 |
+
logger = logging.getLogger(__name__)
|
131 |
+
|
132 |
+
# --- Prompt 和 Few-Shot 加载 (从 todogen_llm.py 迁移并适配) ---
|
133 |
+
def load_single_few_shot_file_hf(file_path: Path) -> str:
|
134 |
+
"""加载单个 few-shot 文件并转义 { 和 }"""
|
135 |
+
try:
|
136 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
137 |
+
content = f.read()
|
138 |
+
escaped_content = content.replace('{', '{{').replace('}', '}}')
|
139 |
+
logger.info(f"✅ 成功加载并转义文件: {file_path}")
|
140 |
+
return escaped_content
|
141 |
+
except FileNotFoundError:
|
142 |
+
logger.warning(f"⚠️ 警告:找不到文件 {file_path}。")
|
143 |
+
return ""
|
144 |
+
except Exception as e:
|
145 |
+
logger.error(f"❌ 加载文件 {file_path} 时出错: {e}", exc_info=True)
|
146 |
+
return ""
|
147 |
+
|
148 |
+
PROMPT_TEMPLATE_CONTENT = ""
|
149 |
+
TRUE_POSITIVE_EXAMPLES_CONTENT = ""
|
150 |
+
FALSE_POSITIVE_EXAMPLES_CONTENT = ""
|
151 |
+
|
152 |
+
def load_prompt_data_hf():
|
153 |
+
global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
|
154 |
+
paths = get_hf_paths_config()
|
155 |
+
try:
|
156 |
+
with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
|
157 |
+
PROMPT_TEMPLATE_CONTENT = f.read()
|
158 |
+
logger.info(f"✅ 成功加载 Prompt 模板文件: {paths['prompt_template']}")
|
159 |
+
except FileNotFoundError:
|
160 |
+
logger.error(f"❌ 错误:找不到 Prompt 模板文件:{paths['prompt_template']}")
|
161 |
+
PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."
|
162 |
+
|
163 |
+
TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
|
164 |
+
FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])
|
165 |
+
|
166 |
+
# 应用启动时加载 prompts
|
167 |
+
load_prompt_data_hf()
|
168 |
+
|
169 |
+
# --- JSON 解析器 (从 todogen_llm.py 迁移) ---
|
170 |
+
def json_parser(text: str) -> dict:
|
171 |
+
# 改进的JSON解析器,更健壮地处理各种格式
|
172 |
+
logger.info(f"Attempting to parse: {text[:200]}...")
|
173 |
+
try:
|
174 |
+
# 1. 尝试直接将整个文本作为JSON解析
|
175 |
+
try:
|
176 |
+
parsed_data = json.loads(text)
|
177 |
+
# 使用_process_parsed_json处理解析结果
|
178 |
+
return _process_parsed_json(parsed_data)
|
179 |
+
except json.JSONDecodeError:
|
180 |
+
pass # 如果直接解析失败,继续尝试提取代码块
|
181 |
+
|
182 |
+
# 2. 尝试从 ```json ... ``` 代码块中提取和解析
|
183 |
+
match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
|
184 |
+
if match:
|
185 |
+
json_str = match.group(1).strip()
|
186 |
+
# 修复常见的JSON格式问题
|
187 |
+
json_str = re.sub(r',\s*]', ']', json_str)
|
188 |
+
json_str = re.sub(r',\s*}', '}', json_str)
|
189 |
+
try:
|
190 |
+
parsed_data = json.loads(json_str)
|
191 |
+
# 使用_process_parsed_json处理解析结果
|
192 |
+
return _process_parsed_json(parsed_data)
|
193 |
+
except json.JSONDecodeError as e_block:
|
194 |
+
logger.warning(f"JSONDecodeError from code block: {e_block} while parsing: {json_str[:200]}")
|
195 |
+
# 如果从代码块解析也失败,则继续
|
196 |
+
|
197 |
+
# 3. 尝试查找最外层的 '{...}' 或 '[...]' 作为JSON
|
198 |
+
# 先尝试查找数组格式 [...]
|
199 |
+
array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
|
200 |
+
if array_match:
|
201 |
+
potential_json = array_match.group(0).strip()
|
202 |
+
try:
|
203 |
+
parsed_data = json.loads(potential_json)
|
204 |
+
# 使用_process_parsed_json处理解析结果
|
205 |
+
return _process_parsed_json(parsed_data)
|
206 |
+
except json.JSONDecodeError:
|
207 |
+
logger.warning(f"Could not parse potential JSON array: {potential_json[:200]}")
|
208 |
+
pass
|
209 |
+
|
210 |
+
# 再尝试查找单个对象格式 {...}
|
211 |
+
object_match = re.search(r'\{.*?\}', text, re.DOTALL)
|
212 |
+
if object_match:
|
213 |
+
potential_json = object_match.group(0).strip()
|
214 |
+
try:
|
215 |
+
parsed_data = json.loads(potential_json)
|
216 |
+
# 使用_process_parsed_json处理解析结果
|
217 |
+
return _process_parsed_json(parsed_data)
|
218 |
+
except json.JSONDecodeError:
|
219 |
+
logger.warning(f"Could not parse potential JSON object: {potential_json[:200]}")
|
220 |
+
pass
|
221 |
+
|
222 |
+
# 4. 如果所有尝试都失败,返回错误信息
|
223 |
+
logger.error(f"Failed to find or parse JSON block in text: {text[:500]}") # 增加日志长度
|
224 |
+
return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
|
225 |
+
|
226 |
+
except Exception as e: # 捕获所有其他意外错误
|
227 |
+
logger.error(f"Unexpected error in json_parser: {e} for text: {text[:200]}", exc_info=True)
|
228 |
+
return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
|
229 |
+
|
230 |
+
def _process_parsed_json(parsed_data):
|
231 |
+
"""处理解析后的JSON数据,确保返回有效的数据结构"""
|
232 |
+
try:
|
233 |
+
# 如果解析结果是空列表,���回包含空字典的列表
|
234 |
+
if isinstance(parsed_data, list):
|
235 |
+
if not parsed_data:
|
236 |
+
logger.warning("JSON解析结果为空列表,返回包含空字典的列表")
|
237 |
+
return [{}]
|
238 |
+
|
239 |
+
# 确保列表中的每个元素都是字典
|
240 |
+
processed_list = []
|
241 |
+
for item in parsed_data:
|
242 |
+
if isinstance(item, dict):
|
243 |
+
processed_list.append(item)
|
244 |
+
else:
|
245 |
+
# 如果不是字典,将其转换为字典
|
246 |
+
try:
|
247 |
+
processed_list.append({"content": str(item)})
|
248 |
+
except:
|
249 |
+
processed_list.append({"content": "无法转换的项目"})
|
250 |
+
|
251 |
+
# 如果处理后的列表为空,返回包含空字典的列表
|
252 |
+
if not processed_list:
|
253 |
+
logger.warning("处理后的JSON列表为空,返回包含空字典的列表")
|
254 |
+
return [{}]
|
255 |
+
|
256 |
+
return processed_list
|
257 |
+
|
258 |
+
# 如果是字典,直接返回
|
259 |
+
elif isinstance(parsed_data, dict):
|
260 |
+
return parsed_data
|
261 |
+
|
262 |
+
# 如果是其他类型,转换为字典
|
263 |
+
else:
|
264 |
+
logger.warning(f"JSON解析结果不是列表或字典,而是{type(parsed_data)},转换为字典")
|
265 |
+
return {"content": str(parsed_data)}
|
266 |
+
|
267 |
+
except Exception as e:
|
268 |
+
logger.error(f"处理解析后的JSON数据时出错: {e}")
|
269 |
+
return {"error": f"Error processing parsed JSON: {e}"}
|
270 |
+
|
271 |
+
# --- Filter 模块的 System Prompt (从 filter_message/libs.py 迁移) ---
|
272 |
+
FILTER_SYSTEM_PROMPT = """
|
273 |
+
# 角色
|
274 |
+
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
|
275 |
+
|
276 |
+
# 任务
|
277 |
+
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
|
278 |
+
主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
|
279 |
+
|
280 |
+
# 要求
|
281 |
+
1. 以json格式输出
|
282 |
+
2. content简洁提炼关键词,字符数<20以内
|
283 |
+
3. 输入条数和输出条数完全一样
|
284 |
+
|
285 |
+
# 输出示例
|
286 |
+
```
|
287 |
+
[
|
288 |
+
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
|
289 |
+
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议"}
|
290 |
+
]
|
291 |
+
|
292 |
+
```
|
293 |
+
"""
|
294 |
+
|
295 |
+
# --- Filter 核心逻辑 (从ToDoAgent集成) ---
|
296 |
+
def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
|
297 |
+
logger.info(f"调用 filter_message_with_llm 处理输入: {text_input} (msg_id: {message_id})")
|
298 |
+
|
299 |
+
# 构造发送给 LLM 的消息
|
300 |
+
# filter 模块的 send_llm_with_prompt 接收的是 tuple[tuple] 格式的数据
|
301 |
+
# 这里我们只有一个文本输入,需要模拟成那种格式
|
302 |
+
mock_data = [(text_input, message_id)]
|
303 |
+
|
304 |
+
# 使用与ToDoAgent相同的system prompt
|
305 |
+
system_prompt = """
|
306 |
+
# 角色
|
307 |
+
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
|
308 |
+
|
309 |
+
# 任务
|
310 |
+
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、缴费充值、待付(还)款、会议邀约、其他】的可能性百分比。
|
311 |
+
主要对于聊天、问候、回执、结果通知、上月账单等信息不需要收件人进行下一步处理的信息,直接归到其他类进行忽略
|
312 |
+
|
313 |
+
# 要求
|
314 |
+
1. 以json格式输出
|
315 |
+
2. content简洁提炼关键词,字符数<20以内
|
316 |
+
3. 输入条数和输出条数完全一样
|
317 |
+
|
318 |
+
# 输出示例
|
319 |
+
```
|
320 |
+
[
|
321 |
+
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
|
322 |
+
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
|
323 |
+
]
|
324 |
+
```
|
325 |
+
"""
|
326 |
+
|
327 |
+
llm_messages = [
|
328 |
+
{"role": "system", "content": system_prompt},
|
329 |
+
{"role": "user", "content": str(mock_data)}
|
330 |
+
]
|
331 |
+
|
332 |
+
try:
|
333 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
334 |
+
logger.error("Filter API 配置不完整,无法调用 Filter LLM。")
|
335 |
+
return [{"error": "Filter API configuration incomplete", "-": "-"}]
|
336 |
+
|
337 |
+
headers = {
|
338 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
339 |
+
"Accept": "application/json"
|
340 |
+
}
|
341 |
+
payload = {
|
342 |
+
"model": Filter_MODEL_NAME,
|
343 |
+
"messages": llm_messages,
|
344 |
+
"temperature": 0.0, # 为提高准确率,温度为0(与ToDoAgent一致)
|
345 |
+
"top_p": 0.95,
|
346 |
+
"max_tokens": 1024,
|
347 |
+
"stream": False
|
348 |
+
}
|
349 |
+
|
350 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
351 |
+
|
352 |
+
try:
|
353 |
+
response = requests.post(api_url, headers=headers, json=payload)
|
354 |
+
response.raise_for_status() # 检查 HTTP 错误
|
355 |
+
raw_llm_response = response.json()["choices"][0]["message"]["content"]
|
356 |
+
logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...")
|
357 |
+
except requests.exceptions.RequestException as e:
|
358 |
+
logger.error(f"调用 Filter API 失败: {e}")
|
359 |
+
return [{"error": f"Filter API call failed: {e}", "-": "-"}]
|
360 |
+
logger.info(f"Filter LLM 原始回复 (部分): {raw_llm_response[:200]}...")
|
361 |
+
|
362 |
+
# 解析 LLM 响应
|
363 |
+
# 移除可能的代码块标记
|
364 |
+
raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
|
365 |
+
parsed_filter_data = json_parser(raw_llm_response)
|
366 |
+
|
367 |
+
if "error" in parsed_filter_data:
|
368 |
+
logger.error(f"解析 Filter LLM 响应失败: {parsed_filter_data['error']}")
|
369 |
+
return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]
|
370 |
+
|
371 |
+
# 返回解析后的数据
|
372 |
+
if isinstance(parsed_filter_data, list) and parsed_filter_data:
|
373 |
+
# 应用规则:如果分类是欠费缴纳且内容包含"缴费支出",归类为"其他"
|
374 |
+
for item in parsed_filter_data:
|
375 |
+
if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
|
376 |
+
item["分类"] = "其他"
|
377 |
+
|
378 |
+
# 检查是否有遗漏的消息ID(ToDoAgent的补充逻辑)
|
379 |
+
request_id_list = {message_id}
|
380 |
+
response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
|
381 |
+
diff = request_id_list - response_id_list
|
382 |
+
|
383 |
+
if diff:
|
384 |
+
logger.warning(f"Filter LLM 响应中有遗漏的消息ID: {diff}")
|
385 |
+
# 对于遗漏的消息,添加一个默认分类为"其他"的项
|
386 |
+
for missed_id in diff:
|
387 |
+
parsed_filter_data.append({
|
388 |
+
"message_id": missed_id,
|
389 |
+
"content": text_input[:20], # 截取前20个字符作为content
|
390 |
+
"物流取件": 0,
|
391 |
+
"欠费缴纳": 0,
|
392 |
+
"待付(还)款": 0,
|
393 |
+
"会议邀约": 0,
|
394 |
+
"其他": 100,
|
395 |
+
"分类": "其他"
|
396 |
+
})
|
397 |
+
|
398 |
+
return parsed_filter_data
|
399 |
+
else:
|
400 |
+
logger.warning(f"Filter LLM 返回空列表或非预期格式: {parsed_filter_data}")
|
401 |
+
# 返回默认分类为"其他"的项
|
402 |
+
return [{
|
403 |
+
"message_id": message_id,
|
404 |
+
"content": text_input[:20], # 截取前20个字符作为content
|
405 |
+
"物流取件": 0,
|
406 |
+
"欠费缴纳": 0,
|
407 |
+
"待付(还)款": 0,
|
408 |
+
"会议邀约": 0,
|
409 |
+
"其他": 100,
|
410 |
+
"分类": "其他",
|
411 |
+
"error": "Filter LLM returned empty or unexpected format"
|
412 |
+
}]
|
413 |
+
|
414 |
+
except Exception as e:
|
415 |
+
logger.exception(f"调用 Filter LLM 或解析时发生错误 (filter_message_with_llm)")
|
416 |
+
return [{
|
417 |
+
"message_id": message_id,
|
418 |
+
"content": text_input[:20], # 截取前20个字符作为content
|
419 |
+
"物流取件": 0,
|
420 |
+
"欠费缴纳": 0,
|
421 |
+
"待付(还)款": 0,
|
422 |
+
"会议邀约": 0,
|
423 |
+
"其他": 100,
|
424 |
+
"分类": "其他",
|
425 |
+
"error": f"Filter LLM call/parse error: {str(e)}"
|
426 |
+
}]
|
427 |
+
|
428 |
+
# --- ToDo List 生成核心逻辑 (使用迁移的代码) ---
|
429 |
+
def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
|
430 |
+
"""根据输入文本生成 ToDoList (使用迁移的逻辑)"""
|
431 |
+
logger.info(f"调用 generate_todolist_from_text 处理输入: {text_input} (msg_id: {message_id})")
|
432 |
+
|
433 |
+
if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
|
434 |
+
logger.error("Prompt 模板未正确加载,无法生成 ToDoList。")
|
435 |
+
return [["error", "Prompt template not loaded", "-"]]
|
436 |
+
|
437 |
+
current_time_iso = datetime.now(timezone.utc).isoformat()
|
438 |
+
# 转义输入内容中的 { 和 }
|
439 |
+
content_escaped = text_input.replace('{', '{{').replace('}', '}}')
|
440 |
+
|
441 |
+
# 构造 prompt
|
442 |
+
formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
|
443 |
+
true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
|
444 |
+
false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
|
445 |
+
current_time=current_time_iso,
|
446 |
+
message_id=message_id,
|
447 |
+
content_escaped=content_escaped
|
448 |
+
)
|
449 |
+
|
450 |
+
# 添加明确的JSON输出指令
|
451 |
+
enhanced_prompt = formatted_prompt + """
|
452 |
+
|
453 |
+
# 重要提示
|
454 |
+
请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
|
455 |
+
你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
|
456 |
+
"""
|
457 |
+
|
458 |
+
# 构造发送给 LLM 的消息
|
459 |
+
llm_messages = [
|
460 |
+
{"role": "user", "content": enhanced_prompt}
|
461 |
+
]
|
462 |
+
|
463 |
+
logger.info(f"发送给 LLM 的消息 (部分): {str(llm_messages)[:300]}...")
|
464 |
+
|
465 |
+
try:
|
466 |
+
# 根据输入文本智能生成 ToDo List
|
467 |
+
# 如果是移动话费充值提醒类消息
|
468 |
+
if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
|
469 |
+
# 直接生成待办事项,不调用API
|
470 |
+
todo_item = {
|
471 |
+
message_id: {
|
472 |
+
"is_todo": True,
|
473 |
+
"end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
|
474 |
+
"location": "线上:中国移动APP",
|
475 |
+
"todo_content": "缴纳话费",
|
476 |
+
"urgency": "important"
|
477 |
+
}
|
478 |
+
}
|
479 |
+
|
480 |
+
# 转换为表格显示格式 - 合并为一行
|
481 |
+
todo_content = "缴纳话费"
|
482 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
483 |
+
location = todo_item[message_id]["location"]
|
484 |
+
|
485 |
+
# 合并所有信息到任务内容中
|
486 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
487 |
+
|
488 |
+
output_for_df = []
|
489 |
+
output_for_df.append([1, combined_content, "重要"])
|
490 |
+
|
491 |
+
return output_for_df
|
492 |
+
|
493 |
+
# 如果是会议邀约类消息
|
494 |
+
elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
|
495 |
+
# 提取可能的会议时间
|
496 |
+
meeting_time = None
|
497 |
+
meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
|
498 |
+
meeting_match = re.search(meeting_pattern, text_input)
|
499 |
+
|
500 |
+
if meeting_match:
|
501 |
+
# 简单处理,实际应用中应该更精确地解析日期时间
|
502 |
+
meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
|
503 |
+
else:
|
504 |
+
meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
|
505 |
+
|
506 |
+
todo_item = {
|
507 |
+
message_id: {
|
508 |
+
"is_todo": True,
|
509 |
+
"end_time": meeting_time,
|
510 |
+
"location": "线上:会议软件",
|
511 |
+
"todo_content": "参加会议",
|
512 |
+
"urgency": "important"
|
513 |
+
}
|
514 |
+
}
|
515 |
+
|
516 |
+
# 转换为表格显示格式 - 合并为一行
|
517 |
+
todo_content = "参加会议"
|
518 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
519 |
+
location = todo_item[message_id]["location"]
|
520 |
+
|
521 |
+
# 合并所有信息到任务内容中
|
522 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
523 |
+
|
524 |
+
output_for_df = []
|
525 |
+
output_for_df.append([1, combined_content, "重要"])
|
526 |
+
|
527 |
+
return output_for_df
|
528 |
+
|
529 |
+
# 如果是物流取件类消息
|
530 |
+
elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
|
531 |
+
# 提取可能的取件码
|
532 |
+
pickup_code = None
|
533 |
+
code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
|
534 |
+
code_match = re.search(code_pattern, text_input)
|
535 |
+
|
536 |
+
todo_content = "取快递"
|
537 |
+
if code_match:
|
538 |
+
pickup_code = code_match.group(1)
|
539 |
+
todo_content = f"取快递(取件码:{pickup_code})"
|
540 |
+
|
541 |
+
todo_item = {
|
542 |
+
message_id: {
|
543 |
+
"is_todo": True,
|
544 |
+
"end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
|
545 |
+
"location": "线下:快递柜",
|
546 |
+
"todo_content": todo_content,
|
547 |
+
"urgency": "important"
|
548 |
+
}
|
549 |
+
}
|
550 |
+
|
551 |
+
# 转换为表格显示格式 - 合并为一行
|
552 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
553 |
+
location = todo_item[message_id]["location"]
|
554 |
+
|
555 |
+
# 合并所有信息到任务内容中
|
556 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地���: {location})"
|
557 |
+
|
558 |
+
output_for_df = []
|
559 |
+
output_for_df.append([1, combined_content, "重要"])
|
560 |
+
|
561 |
+
return output_for_df
|
562 |
+
|
563 |
+
# 对于其他类型的消息,调用LLM API进行处理
|
564 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
565 |
+
logger.error("Filter API 配置不完整,无法调用 Filter LLM。")
|
566 |
+
return [["error", "Filter API configuration incomplete", "-"]]
|
567 |
+
|
568 |
+
headers = {
|
569 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
570 |
+
"Accept": "application/json"
|
571 |
+
}
|
572 |
+
payload = {
|
573 |
+
"model": Filter_MODEL_NAME,
|
574 |
+
"messages": llm_messages,
|
575 |
+
"temperature": 0.2, # 降低温度以提高一致性
|
576 |
+
"top_p": 0.95,
|
577 |
+
"max_tokens": 1024,
|
578 |
+
"stream": False
|
579 |
+
}
|
580 |
+
|
581 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
582 |
+
|
583 |
+
try:
|
584 |
+
response = requests.post(api_url, headers=headers, json=payload)
|
585 |
+
response.raise_for_status() # 检查 HTTP 错误
|
586 |
+
raw_llm_response = response.json()['choices'][0]['message']['content']
|
587 |
+
logger.info(f"LLM 原始回复 (部分): {raw_llm_response[:200]}...")
|
588 |
+
except requests.exceptions.RequestException as e:
|
589 |
+
logger.error(f"调用 Filter API 失败: {e}")
|
590 |
+
return [["error", f"Filter API call failed: {e}", "-"]]
|
591 |
+
|
592 |
+
# 解析 LLM 响应
|
593 |
+
parsed_todos_data = json_parser(raw_llm_response)
|
594 |
+
|
595 |
+
if "error" in parsed_todos_data:
|
596 |
+
logger.error(f"解析 LLM 响应失败: {parsed_todos_data['error']}")
|
597 |
+
return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]
|
598 |
+
|
599 |
+
# 处理解析后的数据
|
600 |
+
output_for_df = []
|
601 |
+
|
602 |
+
# 如果是字典格式(符合prompt模板输出格式)
|
603 |
+
if isinstance(parsed_todos_data, dict):
|
604 |
+
# 获取消息ID对应的待办信息
|
605 |
+
todo_info = None
|
606 |
+
for key, value in parsed_todos_data.items():
|
607 |
+
if key == message_id or key == str(message_id):
|
608 |
+
todo_info = value
|
609 |
+
break
|
610 |
+
|
611 |
+
if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
|
612 |
+
# 提取待办信息
|
613 |
+
todo_content = todo_info.get("todo_content", "未指定待办内容")
|
614 |
+
end_time = todo_info.get("end_time")
|
615 |
+
location = todo_info.get("location")
|
616 |
+
urgency = todo_info.get("urgency", "unimportant")
|
617 |
+
|
618 |
+
# 准备合并显示的内容
|
619 |
+
combined_content = todo_content
|
620 |
+
|
621 |
+
# 添加截止时间
|
622 |
+
if end_time and end_time != "null":
|
623 |
+
try:
|
624 |
+
date_part = end_time.split("T")[0] if "T" in end_time else end_time
|
625 |
+
combined_content += f" (截止时间: {date_part}"
|
626 |
+
except:
|
627 |
+
combined_content += f" (截止时间: {end_time}"
|
628 |
+
else:
|
629 |
+
combined_content += " ("
|
630 |
+
|
631 |
+
# 添加地点
|
632 |
+
if location and location != "null":
|
633 |
+
combined_content += f", 地点: {location})"
|
634 |
+
else:
|
635 |
+
combined_content += ")"
|
636 |
+
|
637 |
+
# 添加紧急程度
|
638 |
+
urgency_display = "一般"
|
639 |
+
if urgency == "urgent":
|
640 |
+
urgency_display = "紧急"
|
641 |
+
elif urgency == "important":
|
642 |
+
urgency_display = "重要"
|
643 |
+
|
644 |
+
# 创建单行输出
|
645 |
+
output_for_df = []
|
646 |
+
output_for_df.append([1, combined_content, urgency_display])
|
647 |
+
else:
|
648 |
+
# 不是待办事项
|
649 |
+
output_for_df = []
|
650 |
+
output_for_df.append([1, "此消息不包含待办事项", "-"])
|
651 |
+
|
652 |
+
# 如果是旧格式(列表格式)
|
653 |
+
elif isinstance(parsed_todos_data, list):
|
654 |
+
output_for_df = []
|
655 |
+
|
656 |
+
# 检查列表是否为空
|
657 |
+
if not parsed_todos_data:
|
658 |
+
logger.warning("LLM 返回了空列表,无法生成 ToDo 项目")
|
659 |
+
return [[1, "未能生成待办事项", "-"]]
|
660 |
+
|
661 |
+
for i, item in enumerate(parsed_todos_data):
|
662 |
+
if isinstance(item, dict):
|
663 |
+
todo_content = item.get('todo_content', item.get('content', 'N/A'))
|
664 |
+
status = item.get('status', '未完成')
|
665 |
+
urgency = item.get('urgency', 'normal')
|
666 |
+
|
667 |
+
# 合并所有信息到一行
|
668 |
+
combined_content = todo_content
|
669 |
+
|
670 |
+
# 添加截止时间
|
671 |
+
if 'end_time' in item and item['end_time']:
|
672 |
+
try:
|
673 |
+
if isinstance(item['end_time'], str):
|
674 |
+
date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
|
675 |
+
combined_content += f" (截止时间: {date_part}"
|
676 |
+
else:
|
677 |
+
combined_content += f" (截止时间: {str(item['end_time'])}"
|
678 |
+
except Exception as e:
|
679 |
+
logger.warning(f"处理end_time时出错: {e}")
|
680 |
+
combined_content += " ("
|
681 |
+
else:
|
682 |
+
combined_content += " ("
|
683 |
+
|
684 |
+
# 添加地点
|
685 |
+
if 'location' in item and item['location']:
|
686 |
+
combined_content += f", 地点: {item['location']})"
|
687 |
+
else:
|
688 |
+
combined_content += ")"
|
689 |
+
|
690 |
+
# 设置重要等级
|
691 |
+
importance = "一般"
|
692 |
+
if urgency == "urgent":
|
693 |
+
importance = "紧急"
|
694 |
+
elif urgency == "important":
|
695 |
+
importance = "重要"
|
696 |
+
|
697 |
+
output_for_df.append([i + 1, combined_content, importance])
|
698 |
+
else:
|
699 |
+
# 如果不是字典,转换为字符串并添加到列表
|
700 |
+
try:
|
701 |
+
item_str = str(item) if item is not None else "未知项目"
|
702 |
+
output_for_df.append([i + 1, item_str, "一般"])
|
703 |
+
except Exception as e:
|
704 |
+
logger.warning(f"处理非字典项目时出错: {e}")
|
705 |
+
output_for_df.append([i + 1, "处理错误的项目", "一般"])
|
706 |
+
|
707 |
+
if not output_for_df:
|
708 |
+
logger.info("LLM 解析结果为空或无法转换为DataFrame格式。")
|
709 |
+
return [["info", "未发现待办事项", "-"]]
|
710 |
+
|
711 |
+
return output_for_df
|
712 |
+
|
713 |
+
except Exception as e:
|
714 |
+
logger.exception(f"调用 LLM 或解析时发生错误 (generate_todolist_from_text)")
|
715 |
+
return [["error", f"LLM call/parse error: {str(e)}", "-"]]
|
716 |
+
|
717 |
+
#gradio
|
718 |
+
def process(audio, image, request: gr.Request):
|
719 |
+
"""处理语音和图片的示例函数"""
|
720 |
+
# 获取并记录客户端IP
|
721 |
+
client_ip = get_client_ip(request, True)
|
722 |
+
print(f"Processing audio/image request from IP: {client_ip}")
|
723 |
+
|
724 |
+
if audio is not None:
|
725 |
+
sample_rate, audio_data = audio
|
726 |
+
audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
|
727 |
+
else:
|
728 |
+
audio_info = "未收到音频"
|
729 |
+
|
730 |
+
if image is not None:
|
731 |
+
image_info = f"图片尺寸: {image.shape}"
|
732 |
+
else:
|
733 |
+
image_info = "未收到图片"
|
734 |
+
|
735 |
+
return audio_info, image_info
|
736 |
+
|
737 |
+
def respond(
|
738 |
+
message,
|
739 |
+
history: list[tuple[str, str]],
|
740 |
+
system_message,
|
741 |
+
max_tokens,
|
742 |
+
temperature,
|
743 |
+
top_p,
|
744 |
+
audio, # 多模态输入:音频
|
745 |
+
image # 多模态输入:图片
|
746 |
+
):
|
747 |
+
# ... (聊天回复逻辑基本保持不变, 但确保 client 使用的是配置好的 HF client)
|
748 |
+
# 1. 多模态处理接口 (其他人负责)
|
749 |
+
# processed_text_from_multimodal = multimodal_placeholder_function(audio, image)
|
750 |
+
# 多模态处理:调用讯飞API进行语音和图像识别
|
751 |
+
multimodal_content = ""
|
752 |
+
|
753 |
+
# 多模态处理配置已移至具体处理部分
|
754 |
+
|
755 |
+
if audio is not None:
|
756 |
+
try:
|
757 |
+
audio_sample_rate, audio_data = audio
|
758 |
+
multimodal_content += f"\n[音频信息: 采样率 {audio_sample_rate}Hz, 时长 {len(audio_data)/audio_sample_rate:.2f}秒]"
|
759 |
+
|
760 |
+
# 调用Azure Speech语音识别
|
761 |
+
azure_speech_config = get_hf_azure_speech_config()
|
762 |
+
azure_speech_key = azure_speech_config.get('key')
|
763 |
+
azure_speech_region = azure_speech_config.get('region')
|
764 |
+
|
765 |
+
if azure_speech_key and azure_speech_region:
|
766 |
+
import tempfile
|
767 |
+
import soundfile as sf
|
768 |
+
import os
|
769 |
+
|
770 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
|
771 |
+
sf.write(temp_audio.name, audio_data, audio_sample_rate)
|
772 |
+
temp_audio_path = temp_audio.name
|
773 |
+
|
774 |
+
audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, temp_audio_path)
|
775 |
+
if audio_text:
|
776 |
+
multimodal_content += f"\n[语音识别结果: {audio_text}]"
|
777 |
+
else:
|
778 |
+
multimodal_content += "\n[语音识别失败]"
|
779 |
+
|
780 |
+
os.unlink(temp_audio_path)
|
781 |
+
else:
|
782 |
+
multimodal_content += "\n[Azure Speech API配置不完整,无法进行语音识别]"
|
783 |
+
|
784 |
+
except Exception as e:
|
785 |
+
multimodal_content += f"\n[音频处理错误: {str(e)}]"
|
786 |
+
|
787 |
+
if image is not None:
|
788 |
+
try:
|
789 |
+
multimodal_content += f"\n[图片信息: 尺寸 {image.shape}]"
|
790 |
+
|
791 |
+
# 调用讯飞图像识别
|
792 |
+
if xunfei_appid and xunfei_apikey and xunfei_apisecret:
|
793 |
+
import tempfile
|
794 |
+
from PIL import Image
|
795 |
+
import os
|
796 |
+
|
797 |
+
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
|
798 |
+
if len(image.shape) == 3: # RGB图像
|
799 |
+
pil_image = Image.fromarray(image.astype('uint8'), 'RGB')
|
800 |
+
else: # 灰度图像
|
801 |
+
pil_image = Image.fromarray(image.astype('uint8'), 'L')
|
802 |
+
|
803 |
+
pil_image.save(temp_image.name, 'JPEG')
|
804 |
+
temp_image_path = temp_image.name
|
805 |
+
|
806 |
+
image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=temp_image_path)
|
807 |
+
if image_text:
|
808 |
+
multimodal_content += f"\n[图像识别结果: {image_text}]"
|
809 |
+
else:
|
810 |
+
multimodal_content += "\n[图像识别失败]"
|
811 |
+
|
812 |
+
os.unlink(temp_image_path)
|
813 |
+
else:
|
814 |
+
multimodal_content += "\n[讯飞API配置不完整,无法进行图像识别]"
|
815 |
+
|
816 |
+
except Exception as e:
|
817 |
+
multimodal_content += f"\n[图像处理错误: {str(e)}]"
|
818 |
+
|
819 |
+
# 将多模态内容(或其处理结果)与用户文本消息结合
|
820 |
+
# combined_message = message
|
821 |
+
# if multimodal_content: # 如果有多模态内容,则附加
|
822 |
+
# combined_message += "\n" + multimodal_content
|
823 |
+
# 为了聊天模型的连贯性,聊天部分可能只使用原始 message
|
824 |
+
# 而 ToDoList 生成则使用 combined_message
|
825 |
+
|
826 |
+
# 聊天回复生成
|
827 |
+
chat_messages = [{"role": "system", "content": system_message}]
|
828 |
+
for val in history:
|
829 |
+
if val[0]:
|
830 |
+
chat_messages.append({"role": "user", "content": val[0]})
|
831 |
+
if val[1]:
|
832 |
+
chat_messages.append({"role": "assistant", "content": val[1]})
|
833 |
+
chat_messages.append({"role": "user", "content": message}) # 聊天机器人使用原始消息
|
834 |
+
|
835 |
+
chat_response_stream = ""
|
836 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
837 |
+
logger.error("Filter API 配置不完整,无法调用 LLM333。")
|
838 |
+
yield "Filter API 配置不完整,无法提供聊天回复。", []
|
839 |
+
return
|
840 |
+
|
841 |
+
headers = {
|
842 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
843 |
+
"Accept": "application/json"
|
844 |
+
}
|
845 |
+
payload = {
|
846 |
+
"model": Filter_MODEL_NAME,
|
847 |
+
"messages": chat_messages,
|
848 |
+
"temperature": temperature,
|
849 |
+
"top_p": top_p,
|
850 |
+
"max_tokens": max_tokens,
|
851 |
+
"stream": True # 聊天通常需要流式输出
|
852 |
+
}
|
853 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
854 |
+
|
855 |
+
try:
|
856 |
+
response = requests.post(api_url, headers=headers, json=payload, stream=True)
|
857 |
+
response.raise_for_status() # 检查 HTTP 错误
|
858 |
+
|
859 |
+
for chunk in response.iter_content(chunk_size=None):
|
860 |
+
if chunk:
|
861 |
+
try:
|
862 |
+
# NVIDIA API 的流式输出是 SSE 格式,需要解析
|
863 |
+
# 每一行以 'data: ' 开头,后面是 JSON
|
864 |
+
for line in chunk.decode('utf-8').splitlines():
|
865 |
+
if line.startswith('data: '):
|
866 |
+
json_data = line[len('data: '):]
|
867 |
+
if json_data.strip() == '[DONE]':
|
868 |
+
break
|
869 |
+
data = json.loads(json_data)
|
870 |
+
# 检查 choices 列表是否存在且不为空
|
871 |
+
if 'choices' in data and len(data['choices']) > 0:
|
872 |
+
token = data['choices'][0]['delta'].get('content', '')
|
873 |
+
if token:
|
874 |
+
chat_response_stream += token
|
875 |
+
yield chat_response_stream, []
|
876 |
+
except json.JSONDecodeError:
|
877 |
+
logger.warning(f"无法解析流式响应块: {chunk.decode('utf-8')}")
|
878 |
+
except Exception as e:
|
879 |
+
logger.error(f"处理流式响应时发生错误: {e}")
|
880 |
+
yield chat_response_stream + f"\n\n错误: {e}", []
|
881 |
+
|
882 |
+
except requests.exceptions.RequestException as e:
|
883 |
+
logger.error(f"调用 NVIDIA API 失败: {e}")
|
884 |
+
yield f"调用 NVIDIA API 失败: {e}", []
|
885 |
+
|
886 |
+
# 全局变量存储所有待办事项
|
887 |
+
all_todos_global = []
|
888 |
+
|
889 |
+
# 创建自定义的聊天界面
|
890 |
+
with gr.Blocks() as app:
|
891 |
+
gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")
|
892 |
+
|
893 |
+
with gr.Row():
|
894 |
+
with gr.Column(scale=2):
|
895 |
+
gr.Markdown("## Chat Interface")
|
896 |
+
chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages") # 推荐使用 type="messages"
|
897 |
+
msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")
|
898 |
+
|
899 |
+
with gr.Row():
|
900 |
+
audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
|
901 |
+
image_input = gr.Image(label="上传图片", type="numpy")
|
902 |
+
|
903 |
+
with gr.Accordion("高级设置", open=False):
|
904 |
+
system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
|
905 |
+
max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)") # 增加聊天模型参数范围
|
906 |
+
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
|
907 |
+
top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")
|
908 |
+
|
909 |
+
with gr.Row():
|
910 |
+
submit_btn = gr.Button("发送", variant="primary")
|
911 |
+
clear_btn = gr.Button("清除聊天和ToDo")
|
912 |
+
|
913 |
+
with gr.Column(scale=1):
|
914 |
+
gr.Markdown("## Generated ToDo List")
|
915 |
+
todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
|
916 |
+
datatype=["number", "str", "str"],
|
917 |
+
row_count=(0, "dynamic"),
|
918 |
+
col_count=(3, "fixed"),
|
919 |
+
label="待办事项列表")
|
920 |
+
|
921 |
+
def user(user_message, chat_history):
|
922 |
+
# 将用户消息添加到聊天记录 (Gradio type="messages" 格式)
|
923 |
+
if not chat_history: chat_history = []
|
924 |
+
chat_history.append({"role": "user", "content": user_message})
|
925 |
+
return "", chat_history
|
926 |
+
|
927 |
+
def bot_interaction(chat_history, system_message, max_tokens, temperature, top_p, audio, image):
|
928 |
+
user_message_for_chat = ""
|
929 |
+
if chat_history and chat_history[-1]["role"] == "user":
|
930 |
+
user_message_for_chat = chat_history[-1]["content"]
|
931 |
+
|
932 |
+
# 准备用于 ToDoList 生成的输入文本 (多模态部分由其他人负责)
|
933 |
+
text_for_todolist = user_message_for_chat
|
934 |
+
# 可以在这里添加从 audio/image 提取文本的逻辑,并附加到 text_for_todolist
|
935 |
+
# multimodal_text = process_multimodal_inputs(audio, image) # 假设的函数
|
936 |
+
# if multimodal_text:
|
937 |
+
# text_for_todolist += "\n" + multimodal_text
|
938 |
+
|
939 |
+
# 1. 生成聊天回复 (流式)
|
940 |
+
# 转换 chat_history 从 [{'role':'user', 'content':'...'}, ...] 到 [('user_msg', 'bot_msg'), ...]
|
941 |
+
# respond 函数期望的是 history: list[tuple[str, str]]
|
942 |
+
# 但 Gradio type="messages" 的 chatbot.value 是 [{'role': ..., 'content': ...}, ...]
|
943 |
+
# 需要转换
|
944 |
+
formatted_history_for_respond = []
|
945 |
+
temp_user_msg = None
|
946 |
+
for item in chat_history[:-1]: #排除最后一条用户消息,因为它会作为当前message传入respond
|
947 |
+
if item["role"] == "user":
|
948 |
+
temp_user_msg = item["content"]
|
949 |
+
elif item["role"] == "assistant" and temp_user_msg is not None:
|
950 |
+
formatted_history_for_respond.append((temp_user_msg, item["content"]))
|
951 |
+
temp_user_msg = None
|
952 |
+
elif item["role"] == "assistant" and temp_user_msg is None: # Bot 先说话的情况
|
953 |
+
formatted_history_for_respond.append(("", item["content"]))
|
954 |
+
|
955 |
+
chat_stream_generator = respond(
|
956 |
+
user_message_for_chat,
|
957 |
+
formatted_history_for_respond, # 传递转换后的历史
|
958 |
+
system_message,
|
959 |
+
max_tokens,
|
960 |
+
temperature,
|
961 |
+
top_p,
|
962 |
+
audio,
|
963 |
+
image
|
964 |
+
)
|
965 |
+
|
966 |
+
full_chat_response = ""
|
967 |
+
current_todos = []
|
968 |
+
|
969 |
+
for chat_response_part, _ in chat_stream_generator:
|
970 |
+
full_chat_response = chat_response_part
|
971 |
+
# 更新 chat_history (Gradio type="messages" 格式)
|
972 |
+
if chat_history and chat_history[-1]["role"] == "user":
|
973 |
+
# 如果最后一条是用户消息,添加机器人回复
|
974 |
+
# 但由于是流式,我们可能需要先添加一个空的 assistant 消息,然后更新它
|
975 |
+
# 或者,等待流结束后一次性添加
|
976 |
+
# 为了简化,我们先假设 respond 返回的是完整回复,或者在循环外更新
|
977 |
+
pass # 流式更新 chatbot 在 submit_btn.click 中处理
|
978 |
+
yield chat_history + [[None, full_chat_response]], current_todos # 临时做法,需要适配Gradio的流式更新
|
979 |
+
|
980 |
+
# 流式结束后,更新 chat_history 中的最后一条 assistant 消息
|
981 |
+
if chat_history and full_chat_response:
|
982 |
+
# 查找最后一条用户消息,在其后添加或更新机器人回复
|
983 |
+
# 这种方式对于 type="messages" 更友好
|
984 |
+
# 实际上,Gradio 的 chatbot 更新应该在 .then() 中处理,这里先模拟
|
985 |
+
# chat_history.append({"role": "assistant", "content": full_chat_response})
|
986 |
+
# 这个 yield 应该在 submit_btn.click 的 .then() 中处理 chatbot 的更新
|
987 |
+
# 这里我们先专注于 ToDo 生成
|
988 |
+
pass # chatbot 更新由 Gradio 机制处理
|
989 |
+
|
990 |
+
# 2. 聊天回复完成后,生成/更新 ToDoList
|
991 |
+
if text_for_todolist:
|
992 |
+
# 使用一个唯一的 ID,例如基于时间戳或随机数,如果需要区分不同输入的 ToDo
|
993 |
+
message_id_for_todo = f"hf_app_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
994 |
+
new_todo_items = generate_todolist_from_text(text_for_todolist, message_id_for_todo)
|
995 |
+
current_todos = new_todo_items
|
996 |
+
|
997 |
+
# bot_interaction 应该返回 chatbot 的最终状态和 todolist_df 的数据
|
998 |
+
# chatbot 的最终状态是 chat_history + assistant 的回复
|
999 |
+
final_chat_history = list(chat_history) # 复制
|
1000 |
+
if full_chat_response:
|
1001 |
+
final_chat_history.append({"role": "assistant", "content": full_chat_response})
|
1002 |
+
|
1003 |
+
yield final_chat_history, current_todos
|
1004 |
+
|
1005 |
+
# 连接事件 (适配 type="messages")
|
1006 |
+
# Gradio 的流式更新通常是:
|
1007 |
+
# 1. user 函数准备输入,返回 (空输入框, 更新后的聊天记录)
|
1008 |
+
# 2. bot_interaction 函数是一个生成器,yield (部分聊天记录, 部分ToDo)
|
1009 |
+
# msg.submit 和 submit_btn.click 的 outputs 需要对应 bot_interaction 的 yield
|
1010 |
+
|
1011 |
+
# 简化版,非流式更新 chatbot,流式更新由 respond 内部的 yield 控制
|
1012 |
+
# 但 respond 的 yield 格式 (str, list) 与 bot_interaction (list, list) 不同
|
1013 |
+
# 需要调整 respond 的 yield 或 bot_interaction 的处理
|
1014 |
+
|
1015 |
+
# 调整后的事件处理,以更好地支持流式聊天和ToDo更新
|
1016 |
+
def process_filtered_result_for_todo(filtered_result, content, source_type):
|
1017 |
+
"""处理过滤结果并生成todolist的辅助函数"""
|
1018 |
+
todos = []
|
1019 |
+
|
1020 |
+
if isinstance(filtered_result, dict) and "error" in filtered_result:
|
1021 |
+
logger.error(f"{source_type} Filter 模块处理失败: {filtered_result['error']}")
|
1022 |
+
todos = [["Error", f"{source_type}: {filtered_result['error']}", "Filter Failed"]]
|
1023 |
+
elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
|
1024 |
+
logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
|
1025 |
+
todos = [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
|
1026 |
+
elif isinstance(filtered_result, list):
|
1027 |
+
# 处理列表类型的过滤结果
|
1028 |
+
category = None
|
1029 |
+
if filtered_result:
|
1030 |
+
for item in filtered_result:
|
1031 |
+
if isinstance(item, dict) and "分类" in item:
|
1032 |
+
category = item["分类"]
|
1033 |
+
break
|
1034 |
+
|
1035 |
+
if category == "其他":
|
1036 |
+
logger.info(f"{source_type}消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
|
1037 |
+
todos = [["Info", f"{source_type}: 消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
|
1038 |
+
else:
|
1039 |
+
logger.info(f"{source_type}消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。")
|
1040 |
+
if content:
|
1041 |
+
msg_id_todo = f"hf_app_todo_{source_type}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
1042 |
+
todos = generate_todolist_from_text(content, msg_id_todo)
|
1043 |
+
# 为每个todo添加来源标识
|
1044 |
+
for todo in todos:
|
1045 |
+
if len(todo) > 1:
|
1046 |
+
todo[1] = f"[{source_type}] {todo[1]}"
|
1047 |
+
else:
|
1048 |
+
# 如果是字典但不是"其他"分类
|
1049 |
+
logger.info(f"{source_type}消息被 Filter 模块归类为 '{filtered_result.get('分类') if isinstance(filtered_result, dict) else '未知'}',继续生成 ToDo List。")
|
1050 |
+
if content:
|
1051 |
+
msg_id_todo = f"hf_app_todo_{source_type}_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
1052 |
+
todos = generate_todolist_from_text(content, msg_id_todo)
|
1053 |
+
# 为每个todo添加来源标识
|
1054 |
+
for todo in todos:
|
1055 |
+
if len(todo) > 1:
|
1056 |
+
todo[1] = f"[{source_type}] {todo[1]}"
|
1057 |
+
|
1058 |
+
return todos
|
1059 |
+
|
1060 |
+
def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f, request: gr.Request):
|
1061 |
+
global all_todos_global
|
1062 |
+
|
1063 |
+
# 获取并记录客户端IP
|
1064 |
+
client_ip = get_client_ip(request, True)
|
1065 |
+
print(f"Processing request from IP: {client_ip}")
|
1066 |
+
|
1067 |
+
# 首先处理多模态输入,获取多模态内容
|
1068 |
+
multimodal_text_content = ""
|
1069 |
+
# 添加调试日志
|
1070 |
+
logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")
|
1071 |
+
|
1072 |
+
# 获取Azure Speech配置
|
1073 |
+
azure_speech_config = get_hf_azure_speech_config()
|
1074 |
+
azure_speech_key = azure_speech_config.get('key')
|
1075 |
+
azure_speech_region = azure_speech_config.get('region')
|
1076 |
+
|
1077 |
+
# 添加调试日志
|
1078 |
+
logger.info(f"Azure Speech配置状态 - key: {bool(azure_speech_key)}, region: {bool(azure_speech_region)}")
|
1079 |
+
|
1080 |
+
# 处理音频输入(使用Azure Speech服务)
|
1081 |
+
if audio_f is not None and azure_speech_key and azure_speech_region:
|
1082 |
+
logger.info("开始处理音频输入...")
|
1083 |
+
try:
|
1084 |
+
audio_sample_rate, audio_data = audio_f
|
1085 |
+
logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")
|
1086 |
+
|
1087 |
+
# 保存音频为.wav文件
|
1088 |
+
audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
|
1089 |
+
save_audio(audio_f, audio_filename)
|
1090 |
+
logger.info(f"音频已保存: {audio_filename}")
|
1091 |
+
|
1092 |
+
# 调用Azure Speech服务处理音频
|
1093 |
+
audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename)
|
1094 |
+
logger.info(f"音频识别结果: {audio_text}")
|
1095 |
+
if audio_text:
|
1096 |
+
multimodal_text_content += f"音频内容: {audio_text}"
|
1097 |
+
logger.info("音频处理完成")
|
1098 |
+
else:
|
1099 |
+
logger.warning("音频处理失败")
|
1100 |
+
except Exception as e:
|
1101 |
+
logger.error(f"音频处理错误: {str(e)}")
|
1102 |
+
elif audio_f is not None:
|
1103 |
+
logger.warning("音频文件存在但Azure Speech配置不完整,跳过音频处理")
|
1104 |
+
|
1105 |
+
# 处理图像输入(使用Azure Computer Vision服务)
|
1106 |
+
if image_f is not None:
|
1107 |
+
logger.info("开始处理图像输入...")
|
1108 |
+
try:
|
1109 |
+
logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")
|
1110 |
+
|
1111 |
+
# 保存图片为.jpg文件
|
1112 |
+
image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
|
1113 |
+
save_image(image_f, image_filename)
|
1114 |
+
logger.info(f"图像已保存: {image_filename}")
|
1115 |
+
|
1116 |
+
# 调用tools.py中的image_to_str方法处理图片
|
1117 |
+
image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
|
1118 |
+
logger.info(f"图像识别结果: {image_text}")
|
1119 |
+
if image_text:
|
1120 |
+
if multimodal_text_content: # 如果已有音频内容,添加分隔符
|
1121 |
+
multimodal_text_content += "\n"
|
1122 |
+
multimodal_text_content += f"图像内容: {image_text}"
|
1123 |
+
logger.info("图像处理完成")
|
1124 |
+
else:
|
1125 |
+
logger.warning("图像处理失败")
|
1126 |
+
except Exception as e:
|
1127 |
+
logger.error(f"图像处理错误: {str(e)}")
|
1128 |
+
elif image_f is not None:
|
1129 |
+
logger.warning("图像文件存在但处理失败,跳过图像处理")
|
1130 |
+
|
1131 |
+
# 确定最终的用户输入内容:如果用户没有输入文本,使用多模态识别的内容
|
1132 |
+
final_user_content = user_msg_content.strip() if user_msg_content else ""
|
1133 |
+
if not final_user_content and multimodal_text_content:
|
1134 |
+
final_user_content = multimodal_text_content
|
1135 |
+
logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
|
1136 |
+
elif final_user_content and multimodal_text_content:
|
1137 |
+
# 用户有文本输入,多模态内容作为补充
|
1138 |
+
final_user_content = f"{final_user_content}\n{multimodal_text_content}"
|
1139 |
+
logger.info(f"用户有文本输入,多模态内容作为补充")
|
1140 |
+
|
1141 |
+
# 如果最终还是没有任何内容,提供默认提示
|
1142 |
+
if not final_user_content:
|
1143 |
+
final_user_content = "[无输入内容]"
|
1144 |
+
logger.warning("用户没有提供任何输入内容(文本、音频或图像)")
|
1145 |
+
|
1146 |
+
logger.info(f"最终用户输入内容: {final_user_content}")
|
1147 |
+
|
1148 |
+
# 1. 更新聊天记录 (用户部分) - 使用最终确定的用户内容
|
1149 |
+
if not ch_history: ch_history = []
|
1150 |
+
ch_history.append({"role": "user", "content": final_user_content})
|
1151 |
+
yield ch_history, [] # 更新聊天,ToDo 列表暂时不变
|
1152 |
+
|
1153 |
+
# 2. 流式生成机器人回复并更新聊天记录
|
1154 |
+
# 转换 chat_history 为 respond 函数期望的格式
|
1155 |
+
formatted_hist_for_respond = []
|
1156 |
+
temp_user_msg_for_hist = None
|
1157 |
+
# 使用 ch_history[:-1] 因为当前用户消息已在 ch_history 中
|
1158 |
+
for item_hist in ch_history[:-1]:
|
1159 |
+
if item_hist["role"] == "user":
|
1160 |
+
temp_user_msg_for_hist = item_hist["content"]
|
1161 |
+
elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
|
1162 |
+
formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
|
1163 |
+
temp_user_msg_for_hist = None
|
1164 |
+
elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
|
1165 |
+
formatted_hist_for_respond.append(("", item_hist["content"]))
|
1166 |
+
|
1167 |
+
# 准备一个 assistant 消息的槽位
|
1168 |
+
ch_history.append({"role": "assistant", "content": ""})
|
1169 |
+
|
1170 |
+
full_bot_response = ""
|
1171 |
+
# 使用最终确定的用户内容进行对话
|
1172 |
+
for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
|
1173 |
+
full_bot_response = bot_response_token
|
1174 |
+
ch_history[-1]["content"] = full_bot_response # 更新最后一条 assistant 消息
|
1175 |
+
yield ch_history, [] # 流式更新聊天,ToDo 列表不变
|
1176 |
+
|
1177 |
+
# 3. 生成 ToDoList - 分别处理音频、图片和文字输入
|
1178 |
+
new_todos_list = []
|
1179 |
+
|
1180 |
+
# 分别处理文字输入
|
1181 |
+
if user_msg_content and user_msg_content.strip():
|
1182 |
+
logger.info(f"处理文字输入生成ToDo: {user_msg_content.strip()}")
|
1183 |
+
text_filtered_result = filter_message_with_llm(user_msg_content.strip())
|
1184 |
+
text_todos = process_filtered_result_for_todo(text_filtered_result, user_msg_content.strip(), "文字")
|
1185 |
+
new_todos_list.extend(text_todos)
|
1186 |
+
|
1187 |
+
# 分别处理音频输入
|
1188 |
+
if audio_f is not None and azure_speech_key and azure_speech_region:
|
1189 |
+
try:
|
1190 |
+
audio_sample_rate, audio_data = audio_f
|
1191 |
+
audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
|
1192 |
+
save_audio(audio_f, audio_filename)
|
1193 |
+
audio_text = azure_speech_to_text(azure_speech_key, azure_speech_region, audio_filename)
|
1194 |
+
if audio_text:
|
1195 |
+
logger.info(f"处理音频输入生成ToDo: {audio_text}")
|
1196 |
+
audio_filtered_result = filter_message_with_llm(audio_text)
|
1197 |
+
audio_todos = process_filtered_result_for_todo(audio_filtered_result, audio_text, "音频")
|
1198 |
+
new_todos_list.extend(audio_todos)
|
1199 |
+
except Exception as e:
|
1200 |
+
logger.error(f"音频处理错误: {str(e)}")
|
1201 |
+
|
1202 |
+
# 分别处理图片输入
|
1203 |
+
if image_f is not None:
|
1204 |
+
try:
|
1205 |
+
image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
|
1206 |
+
save_image(image_f, image_filename)
|
1207 |
+
image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
|
1208 |
+
if image_text:
|
1209 |
+
logger.info(f"处理图片输入生成ToDo: {image_text}")
|
1210 |
+
image_filtered_result = filter_message_with_llm(image_text)
|
1211 |
+
image_todos = process_filtered_result_for_todo(image_filtered_result, image_text, "图片")
|
1212 |
+
new_todos_list.extend(image_todos)
|
1213 |
+
except Exception as e:
|
1214 |
+
logger.error(f"图片处理错误: {str(e)}")
|
1215 |
+
|
1216 |
+
# 如果没有任何有效输入,使用原有逻辑
|
1217 |
+
if not new_todos_list and final_user_content:
|
1218 |
+
logger.info(f"使用整合内容生成ToDo: {final_user_content}")
|
1219 |
+
filtered_result = filter_message_with_llm(final_user_content)
|
1220 |
+
|
1221 |
+
if isinstance(filtered_result, dict) and "error" in filtered_result:
|
1222 |
+
logger.error(f"Filter 模块处理失败: {filtered_result['error']}")
|
1223 |
+
# 可以选择在这里显示错误信息给用户
|
1224 |
+
new_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
|
1225 |
+
elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
|
1226 |
+
logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
|
1227 |
+
new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
|
1228 |
+
elif isinstance(filtered_result, list):
|
1229 |
+
# 如果返回的是列表,尝试从列表中获取分类信息
|
1230 |
+
category = None
|
1231 |
+
|
1232 |
+
# 检查列表是否为空
|
1233 |
+
if not filtered_result:
|
1234 |
+
logger.warning("Filter 模块返回了空列表,将继续生成 ToDo List。")
|
1235 |
+
if final_user_content:
|
1236 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
1237 |
+
new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
|
1238 |
+
# 将新的待办事项添加到全局列表中
|
1239 |
+
if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])):
|
1240 |
+
# 重新分配ID以确保连续性
|
1241 |
+
for i, todo in enumerate(new_todos_list):
|
1242 |
+
todo[0] = len(all_todos_global) + i + 1
|
1243 |
+
all_todos_global.extend(new_todos_list)
|
1244 |
+
yield ch_history, all_todos_global
|
1245 |
+
return
|
1246 |
+
|
1247 |
+
# 确保列表中至少有一个元素且是字典类型
|
1248 |
+
valid_item = None
|
1249 |
+
for item in filtered_result:
|
1250 |
+
if isinstance(item, dict):
|
1251 |
+
valid_item = item
|
1252 |
+
if "分类" in item:
|
1253 |
+
category = item["分类"]
|
1254 |
+
break
|
1255 |
+
|
1256 |
+
# 如果没有找到有效的字典元素,记录警告并继续生成ToDo
|
1257 |
+
if valid_item is None:
|
1258 |
+
logger.warning(f"Filter 模块返回的列表中没有有效的字典元素: {filtered_result}")
|
1259 |
+
if final_user_content:
|
1260 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
1261 |
+
new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
|
1262 |
+
# 将新的待办事项添加到全局列表中
|
1263 |
+
if new_todos_list and not (len(new_todos_list) == 1 and "Info" in str(new_todos_list[0])):
|
1264 |
+
# 重新分配ID以确保连续性
|
1265 |
+
for i, todo in enumerate(new_todos_list):
|
1266 |
+
todo[0] = len(all_todos_global) + i + 1
|
1267 |
+
all_todos_global.extend(new_todos_list)
|
1268 |
+
yield ch_history, all_todos_global
|
1269 |
+
return
|
1270 |
+
|
1271 |
+
if category == "其他":
|
1272 |
+
logger.info(f"消息被 Filter 模块归类为 '其他',不生成 ToDo List。")
|
1273 |
+
new_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
|
1274 |
+
else:
|
1275 |
+
logger.info(f"消息被 Filter 模块归类为 '{category if category else '未知'}',继续生成 ToDo List。")
|
1276 |
+
# 如果 Filter 结果不是"其他",则继续生成 ToDoList
|
1277 |
+
if final_user_content:
|
1278 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
1279 |
+
new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
|
1280 |
+
else:
|
1281 |
+
# 如果是字典但不是"其他"分类
|
1282 |
+
logger.info(f"消息被 Filter 模块归类为 '{filtered_result.get('分类')}',继续生成 ToDo List。")
|
1283 |
+
# 如果 Filter 结果不是"其他",则继续生成 ToDoList
|
1284 |
+
if final_user_content:
|
1285 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
1286 |
+
new_todos_list = generate_todolist_from_text(final_user_content, msg_id_todo)
|
1287 |
+
|
1288 |
+
# 将新的待办事项添加到全局列表中(排除信息性消息)
|
1289 |
+
if new_todos_list and not (len(new_todos_list) == 1 and ("Info" in str(new_todos_list[0]) or "Error" in str(new_todos_list[0]))):
|
1290 |
+
# 重新分配ID以确保连续性
|
1291 |
+
for i, todo in enumerate(new_todos_list):
|
1292 |
+
todo[0] = len(all_todos_global) + i + 1
|
1293 |
+
all_todos_global.extend(new_todos_list)
|
1294 |
+
|
1295 |
+
yield ch_history, all_todos_global # 最终更新聊天和完整的ToDo列表
|
1296 |
+
|
1297 |
+
submit_btn.click(
|
1298 |
+
handle_submit,
|
1299 |
+
[msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
|
1300 |
+
[chatbot, todolist_df]
|
1301 |
+
)
|
1302 |
+
msg.submit(
|
1303 |
+
handle_submit,
|
1304 |
+
[msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
|
1305 |
+
[chatbot, todolist_df]
|
1306 |
+
)
|
1307 |
+
|
1308 |
+
def clear_all():
|
1309 |
+
global all_todos_global
|
1310 |
+
all_todos_global = [] # 清除全局待办事项列表
|
1311 |
+
return None, None, "" # 清除 chatbot, todolist_df, 和 msg 输入框
|
1312 |
+
clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)
|
1313 |
+
|
1314 |
+
# 旧的 Audio/Image Processing Tab (保持不变或按需修改)
|
1315 |
+
with gr.Tab("Audio/Image Processing (Original)"):
|
1316 |
+
gr.Markdown("## 处理音频和图片")
|
1317 |
+
audio_processor = gr.Audio(label="上传音频", type="numpy")
|
1318 |
+
image_processor = gr.Image(label="上传图片", type="numpy")
|
1319 |
+
process_btn = gr.Button("处理", variant="primary")
|
1320 |
+
audio_output = gr.Textbox(label="音频信息")
|
1321 |
+
image_output = gr.Textbox(label="图片信息")
|
1322 |
+
|
1323 |
+
process_btn.click(
|
1324 |
+
process,
|
1325 |
+
inputs=[audio_processor, image_processor],
|
1326 |
+
outputs=[audio_output, image_output]
|
1327 |
+
)
|
1328 |
+
|
1329 |
+
if __name__ == "__main__":
|
1330 |
+
app.launch(debug=True)
|
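Before the new app_pro.py below, one detail of the app.py wiring above is worth spelling out: `handle_submit` is a generator, so every `yield ch_history, all_todos_global` pushes an intermediate update to the Chatbot and the DataFrame while the reply is still streaming. A minimal sketch of that Gradio pattern, with illustrative handler and component names that are not part of this repo:

```python
import time
import gradio as gr

def echo_stream(user_msg, history):
    """Illustrative generator handler: every yield refreshes the bound outputs."""
    history = (history or []) + [{"role": "user", "content": user_msg}]
    history.append({"role": "assistant", "content": ""})
    partial = ""
    for ch in f"Echo: {user_msg}":
        partial += ch
        history[-1]["content"] = partial
        time.sleep(0.02)   # stand-in for waiting on the next LLM token
        yield history, []  # chat streams, ToDo table left unchanged
    # The final yield also refreshes the second output, like the real handler does.
    yield history, [[1, f"[文字] {user_msg}", "一般"]]

with gr.Blocks() as demo:
    chat = gr.Chatbot(type="messages")
    todos = gr.DataFrame(headers=["ID", "任务内容", "状态"])
    box = gr.Textbox()
    box.submit(echo_stream, [box, chat], [chat, todos])

if __name__ == "__main__":
    demo.launch()
```

Gradio treats a generator event handler as a stream: each yielded tuple is mapped onto the declared outputs in order, which is why the ToDo table only fills in on the final yield.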
app_pro.py
ADDED
@@ -0,0 +1,840 @@
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
import yaml
|
5 |
+
import re
|
6 |
+
import logging
|
7 |
+
import io
|
8 |
+
import sys
|
9 |
+
import re
|
10 |
+
from datetime import datetime, timezone, timedelta
|
11 |
+
import requests
|
12 |
+
from tools import * #gege的多模态
|
13 |
+
|
14 |
+
|
15 |
+
CONFIG = None
|
16 |
+
HF_CONFIG_PATH = Path(__file__).parent / "todogen_LLM_config.yaml"
|
17 |
+
|
18 |
+
def load_hf_config():
|
19 |
+
"""加载YAML配置文件"""
|
20 |
+
global CONFIG
|
21 |
+
if CONFIG is None:
|
22 |
+
try:
|
23 |
+
with open(HF_CONFIG_PATH, 'r', encoding='utf-8') as f:
|
24 |
+
CONFIG = yaml.safe_load(f)
|
25 |
+
print(f"✅ 配置已加载: {HF_CONFIG_PATH}")
|
26 |
+
except FileNotFoundError:
|
27 |
+
print(f"❌ 错误: 配置文件 {HF_CONFIG_PATH} 未找到。请确保它在 hf 目录下。")
|
28 |
+
CONFIG = {}
|
29 |
+
except Exception as e:
|
30 |
+
print(f"❌ 加载配置文件 {HF_CONFIG_PATH} 时出错: {e}")
|
31 |
+
CONFIG = {}
|
32 |
+
return CONFIG
|
33 |
+
|
34 |
+
def get_hf_openai_config():
|
35 |
+
"""获取OpenAI API配置"""
|
36 |
+
config = load_hf_config()
|
37 |
+
return config.get('openai', {})
|
38 |
+
|
39 |
+
def get_hf_openai_filter_config():
|
40 |
+
"""获取Filter API配置"""
|
41 |
+
config = load_hf_config()
|
42 |
+
return config.get('openai_filter', {})
|
43 |
+
|
44 |
+
def get_hf_xunfei_config():
|
45 |
+
"""获取讯飞API配置"""
|
46 |
+
config = load_hf_config()
|
47 |
+
return config.get('xunfei', {})
|
48 |
+
|
49 |
+
def get_hf_paths_config():
|
50 |
+
"""获取文件路径配置"""
|
51 |
+
config = load_hf_config()
|
52 |
+
base = Path(__file__).resolve().parent
|
53 |
+
paths_cfg = config.get('paths', {})
|
54 |
+
return {
|
55 |
+
'base_dir': base,
|
56 |
+
'prompt_template': base / paths_cfg.get('prompt_template', 'prompt_template.txt'),
|
57 |
+
'true_positive_examples': base / paths_cfg.get('true_positive_examples', 'TruePositive_few_shot.txt'),
|
58 |
+
'false_positive_examples': base / paths_cfg.get('false_positive_examples', 'FalsePositive_few_shot.txt'),
|
59 |
+
}
|
60 |
+
|
61 |
+
llm_config = get_hf_openai_config()
|
62 |
+
NVIDIA_API_BASE_URL = llm_config.get('base_url')
|
63 |
+
NVIDIA_API_KEY = llm_config.get('api_key')
|
64 |
+
NVIDIA_MODEL_NAME = llm_config.get('model')
|
65 |
+
|
66 |
+
filter_config = get_hf_openai_filter_config()
|
67 |
+
Filter_API_BASE_URL = filter_config.get('base_url_filter')
|
68 |
+
Filter_API_KEY = filter_config.get('api_key_filter')
|
69 |
+
Filter_MODEL_NAME = filter_config.get('model_filter')
|
70 |
+
|
71 |
+
if not NVIDIA_API_BASE_URL or not NVIDIA_API_KEY or not NVIDIA_MODEL_NAME:
|
72 |
+
print("❌ 错误: NVIDIA API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai 部分。")
|
73 |
+
NVIDIA_API_BASE_URL = ""
|
74 |
+
NVIDIA_API_KEY = ""
|
75 |
+
NVIDIA_MODEL_NAME = ""
|
76 |
+
|
77 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
78 |
+
print("❌ 错误: Filter API 配置不完整。请检查 todogen_LLM_config.yaml 中的 openai_filter 部分。")
|
79 |
+
Filter_API_BASE_URL = ""
|
80 |
+
Filter_API_KEY = ""
|
81 |
+
Filter_MODEL_NAME = ""
|
82 |
+
|
83 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
84 |
+
logger = logging.getLogger(__name__)
|
85 |
+
|
86 |
+
def load_single_few_shot_file_hf(file_path: Path) -> str:
|
87 |
+
"""加载单个few-shot示例文件并转义大括号"""
|
88 |
+
try:
|
89 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
90 |
+
content = f.read()
|
91 |
+
escaped_content = content.replace('{', '{{').replace('}', '}}')
|
92 |
+
return escaped_content
|
93 |
+
except FileNotFoundError:
|
94 |
+
return ""
|
95 |
+
except Exception:
|
96 |
+
return ""
|
97 |
+
|
98 |
+
PROMPT_TEMPLATE_CONTENT = ""
|
99 |
+
TRUE_POSITIVE_EXAMPLES_CONTENT = ""
|
100 |
+
FALSE_POSITIVE_EXAMPLES_CONTENT = ""
|
101 |
+
|
102 |
+
def load_prompt_data_hf():
|
103 |
+
"""加载提示词模板和示例数据"""
|
104 |
+
global PROMPT_TEMPLATE_CONTENT, TRUE_POSITIVE_EXAMPLES_CONTENT, FALSE_POSITIVE_EXAMPLES_CONTENT
|
105 |
+
paths = get_hf_paths_config()
|
106 |
+
try:
|
107 |
+
with open(paths['prompt_template'], 'r', encoding='utf-8') as f:
|
108 |
+
PROMPT_TEMPLATE_CONTENT = f.read()
|
109 |
+
except FileNotFoundError:
|
110 |
+
PROMPT_TEMPLATE_CONTENT = "Error: Prompt template not found."
|
111 |
+
|
112 |
+
TRUE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['true_positive_examples'])
|
113 |
+
FALSE_POSITIVE_EXAMPLES_CONTENT = load_single_few_shot_file_hf(paths['false_positive_examples'])
|
114 |
+
|
115 |
+
load_prompt_data_hf()
|
116 |
+
|
117 |
+
def _process_parsed_json(parsed_data):
|
118 |
+
"""处理解析后的JSON数据,确保格式正确"""
|
119 |
+
try:
|
120 |
+
if isinstance(parsed_data, list):
|
121 |
+
if not parsed_data:
|
122 |
+
return [{}]
|
123 |
+
|
124 |
+
processed_list = []
|
125 |
+
for item in parsed_data:
|
126 |
+
if isinstance(item, dict):
|
127 |
+
processed_list.append(item)
|
128 |
+
else:
|
129 |
+
try:
|
130 |
+
processed_list.append({"content": str(item)})
|
131 |
+
except:
|
132 |
+
processed_list.append({"content": "无法转换的项目"})
|
133 |
+
|
134 |
+
if not processed_list:
|
135 |
+
return [{}]
|
136 |
+
|
137 |
+
return processed_list
|
138 |
+
|
139 |
+
elif isinstance(parsed_data, dict):
|
140 |
+
return parsed_data
|
141 |
+
|
142 |
+
else:
|
143 |
+
return {"content": str(parsed_data)}
|
144 |
+
|
145 |
+
except Exception as e:
|
146 |
+
return {"error": f"Error processing parsed JSON: {e}"}
|
147 |
+
|
148 |
+
def json_parser(text: str) -> dict:
|
149 |
+
"""从文本中解析JSON数据,支持多种格式"""
|
150 |
+
try:
|
151 |
+
try:
|
152 |
+
parsed_data = json.loads(text)
|
153 |
+
return _process_parsed_json(parsed_data)
|
154 |
+
except json.JSONDecodeError:
|
155 |
+
pass
|
156 |
+
|
157 |
+
match = re.search(r'```(?:json)?\n(.*?)```', text, re.DOTALL)
|
158 |
+
if match:
|
159 |
+
json_str = match.group(1).strip()
|
160 |
+
json_str = re.sub(r',\s*]', ']', json_str)
|
161 |
+
json_str = re.sub(r',\s*}', '}', json_str)
|
162 |
+
try:
|
163 |
+
parsed_data = json.loads(json_str)
|
164 |
+
return _process_parsed_json(parsed_data)
|
165 |
+
except json.JSONDecodeError:
|
166 |
+
pass
|
167 |
+
|
168 |
+
array_match = re.search(r'\[\s*\{.*?\}\s*(?:,\s*\{.*?\}\s*)*\]', text, re.DOTALL)
|
169 |
+
if array_match:
|
170 |
+
potential_json = array_match.group(0).strip()
|
171 |
+
try:
|
172 |
+
parsed_data = json.loads(potential_json)
|
173 |
+
return _process_parsed_json(parsed_data)
|
174 |
+
except json.JSONDecodeError:
|
175 |
+
pass
|
176 |
+
|
177 |
+
object_match = re.search(r'\{.*?\}', text, re.DOTALL)
|
178 |
+
if object_match:
|
179 |
+
potential_json = object_match.group(0).strip()
|
180 |
+
try:
|
181 |
+
parsed_data = json.loads(potential_json)
|
182 |
+
return _process_parsed_json(parsed_data)
|
183 |
+
except json.JSONDecodeError:
|
184 |
+
pass
|
185 |
+
|
186 |
+
return {"error": "No valid JSON block found or failed to parse", "raw_text": text}
|
187 |
+
|
188 |
+
except Exception as e:
|
189 |
+
return {"error": f"Unexpected error in json_parser: {e}", "raw_text": text}
|
190 |
+
|
191 |
+
def filter_message_with_llm(text_input: str, message_id: str = "user_input_001"):
|
192 |
+
"""使用LLM对消息进行分类过滤"""
|
193 |
+
mock_data = [(text_input, message_id)]
|
194 |
+
|
195 |
+
system_prompt = """
|
196 |
+
# 角色
|
197 |
+
你是一个专业的短信内容分析助手,根据输入判断内容的类型及可信度,为用户使用信息提供依据和便利。
|
198 |
+
|
199 |
+
# 任务
|
200 |
+
对于输入的多条数据,分析每一条数据内容(主键:`message_id`)属于【物流取件、欠费缴纳、待付(还)款、会议邀约、其他】的可能性百分比。
|
201 |
+
对于聊天、问候、回执、结果通知、上月账单等不需要收件人进行下一步处理的信息,直接归入【其他】类并忽略。
|
202 |
+
|
203 |
+
# 要求
|
204 |
+
1. 以json格式输出
|
205 |
+
2. content简洁提炼关键词,字符数<20以内
|
206 |
+
3. 输入条数和输出条数完全一样
|
207 |
+
|
208 |
+
# 输出示例
|
209 |
+
```
|
210 |
+
[
|
211 |
+
{"message_id":"1111111","content":"账单805.57元待还","物流取件":0,"欠费缴纳":99,"待付(还)款":1,"会议邀约":0,"其他":0, "分类":"欠费缴纳"},
|
212 |
+
{"message_id":"222222","content":"邀请你加入飞书视频会议","物流取件":0,"欠费缴纳":0,"待付(还)款":1,"会议邀约":100,"其他":0, "分类":"会议邀约"}
|
213 |
+
]
|
214 |
+
```
|
215 |
+
"""
|
216 |
+
|
217 |
+
llm_messages = [
|
218 |
+
{"role": "system", "content": system_prompt},
|
219 |
+
{"role": "user", "content": str(mock_data)}
|
220 |
+
]
|
221 |
+
|
222 |
+
try:
|
223 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
224 |
+
return [{"error": "Filter API configuration incomplete", "-": "-"}]
|
225 |
+
|
226 |
+
headers = {
|
227 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
228 |
+
"Accept": "application/json"
|
229 |
+
}
|
230 |
+
payload = {
|
231 |
+
"model": Filter_MODEL_NAME,
|
232 |
+
"messages": llm_messages,
|
233 |
+
"temperature": 0.0,
|
234 |
+
"top_p": 0.95,
|
235 |
+
"max_tokens": 1024,
|
236 |
+
"stream": False
|
237 |
+
}
|
238 |
+
|
239 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
240 |
+
|
241 |
+
try:
|
242 |
+
response = requests.post(api_url, headers=headers, json=payload)
|
243 |
+
response.raise_for_status()
|
244 |
+
raw_llm_response = response.json()["choices"][0]["message"]["content"]
|
245 |
+
except requests.exceptions.RequestException as e:
|
246 |
+
return [{"error": f"Filter API call failed: {e}", "-": "-"}]
|
247 |
+
|
248 |
+
raw_llm_response = raw_llm_response.replace("```json", "").replace("```", "")
|
249 |
+
parsed_filter_data = json_parser(raw_llm_response)
|
250 |
+
|
251 |
+
if "error" in parsed_filter_data:
|
252 |
+
return [{"error": f"Filter LLM response parsing error: {parsed_filter_data['error']}"}]
|
253 |
+
|
254 |
+
if isinstance(parsed_filter_data, list) and parsed_filter_data:
|
255 |
+
for item in parsed_filter_data:
|
256 |
+
if isinstance(item, dict) and item.get("分类") == "欠费缴纳" and "缴费支出" in item.get("content", ""):
|
257 |
+
item["分类"] = "其他"
|
258 |
+
|
259 |
+
request_id_list = {message_id}
|
260 |
+
response_id_list = {item.get('message_id') for item in parsed_filter_data if isinstance(item, dict)}
|
261 |
+
diff = request_id_list - response_id_list
|
262 |
+
|
263 |
+
if diff:
|
264 |
+
for missed_id in diff:
|
265 |
+
parsed_filter_data.append({
|
266 |
+
"message_id": missed_id,
|
267 |
+
"content": text_input[:20],
|
268 |
+
"物流取件": 0,
|
269 |
+
"欠费缴纳": 0,
|
270 |
+
"待付(还)款": 0,
|
271 |
+
"会议邀约": 0,
|
272 |
+
"其他": 100,
|
273 |
+
"分类": "其他"
|
274 |
+
})
|
275 |
+
|
276 |
+
return parsed_filter_data
|
277 |
+
else:
|
278 |
+
return [{
|
279 |
+
"message_id": message_id,
|
280 |
+
"content": text_input[:20],
|
281 |
+
"物流取件": 0,
|
282 |
+
"欠费缴纳": 0,
|
283 |
+
"待付(还)款": 0,
|
284 |
+
"会议邀约": 0,
|
285 |
+
"其他": 100,
|
286 |
+
"分类": "其他",
|
287 |
+
"error": "Filter LLM returned empty or unexpected format"
|
288 |
+
}]
|
289 |
+
|
290 |
+
except Exception as e:
|
291 |
+
return [{
|
292 |
+
"message_id": message_id,
|
293 |
+
"content": text_input[:20],
|
294 |
+
"物流取件": 0,
|
295 |
+
"欠费缴纳": 0,
|
296 |
+
"待付(还)款": 0,
|
297 |
+
"会议邀约": 0,
|
298 |
+
"其他": 100,
|
299 |
+
"分类": "其他",
|
300 |
+
"error": f"Filter LLM call/parse error: {str(e)}"
|
301 |
+
}]
|
302 |
+
|
303 |
+
def generate_todolist_from_text(text_input: str, message_id: str = "user_input_001"):
|
304 |
+
"""从文本生成待办事项列表"""
|
305 |
+
if not PROMPT_TEMPLATE_CONTENT or "Error:" in PROMPT_TEMPLATE_CONTENT:
|
306 |
+
return [["error", "Prompt template not loaded", "-"]]
|
307 |
+
|
308 |
+
current_time_iso = datetime.now(timezone.utc).isoformat()
|
309 |
+
content_escaped = text_input.replace('{', '{{').replace('}', '}}')
|
310 |
+
|
311 |
+
formatted_prompt = PROMPT_TEMPLATE_CONTENT.format(
|
312 |
+
true_positive_examples=TRUE_POSITIVE_EXAMPLES_CONTENT,
|
313 |
+
false_positive_examples=FALSE_POSITIVE_EXAMPLES_CONTENT,
|
314 |
+
current_time=current_time_iso,
|
315 |
+
message_id=message_id,
|
316 |
+
content_escaped=content_escaped
|
317 |
+
)
|
318 |
+
|
319 |
+
enhanced_prompt = formatted_prompt + """
|
320 |
+
|
321 |
+
# 重要提示
|
322 |
+
请确保你的回复是有效的JSON格式,并且只包含JSON内容。不要添加任何额外的解释或文本。
|
323 |
+
你的回复应该严格按照上面的输出示例格式,只包含JSON对象,不要有任何其他文本。
|
324 |
+
"""
|
325 |
+
|
326 |
+
llm_messages = [
|
327 |
+
{"role": "user", "content": enhanced_prompt}
|
328 |
+
]
|
329 |
+
|
330 |
+
try:
|
331 |
+
if ("充值" in text_input or "缴费" in text_input) and ("移动" in text_input or "话费" in text_input or "余额" in text_input):
|
332 |
+
todo_item = {
|
333 |
+
message_id: {
|
334 |
+
"is_todo": True,
|
335 |
+
"end_time": (datetime.now(timezone.utc) + timedelta(days=3)).isoformat(),
|
336 |
+
"location": "线上:中国移动APP",
|
337 |
+
"todo_content": "缴纳话费",
|
338 |
+
"urgency": "important"
|
339 |
+
}
|
340 |
+
}
|
341 |
+
|
342 |
+
todo_content = "缴纳话费"
|
343 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
344 |
+
location = todo_item[message_id]["location"]
|
345 |
+
|
346 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
347 |
+
|
348 |
+
output_for_df = []
|
349 |
+
output_for_df.append([1, combined_content, "重要"])
|
350 |
+
|
351 |
+
return output_for_df
|
352 |
+
|
353 |
+
elif "会议" in text_input and ("邀请" in text_input or "参加" in text_input):
|
354 |
+
meeting_time = None
|
355 |
+
meeting_pattern = r'(\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2}|\d{4}[年/-]\d{1,2}[月/-]\d{1,2}[日号]?\s*\d{1,2}[点:]\d{0,2})'
|
356 |
+
meeting_match = re.search(meeting_pattern, text_input)
|
357 |
+
|
358 |
+
if meeting_match:
|
359 |
+
meeting_time = (datetime.now(timezone.utc) + timedelta(days=1, hours=2)).isoformat()
|
360 |
+
else:
|
361 |
+
meeting_time = (datetime.now(timezone.utc) + timedelta(days=1)).isoformat()
|
362 |
+
|
363 |
+
todo_item = {
|
364 |
+
message_id: {
|
365 |
+
"is_todo": True,
|
366 |
+
"end_time": meeting_time,
|
367 |
+
"location": "线上:会议软件",
|
368 |
+
"todo_content": "参加会议",
|
369 |
+
"urgency": "important"
|
370 |
+
}
|
371 |
+
}
|
372 |
+
|
373 |
+
todo_content = "参加会议"
|
374 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
375 |
+
location = todo_item[message_id]["location"]
|
376 |
+
|
377 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
378 |
+
|
379 |
+
output_for_df = []
|
380 |
+
output_for_df.append([1, combined_content, "重要"])
|
381 |
+
|
382 |
+
return output_for_df
|
383 |
+
|
384 |
+
elif ("快递" in text_input or "物流" in text_input or "取件" in text_input) and ("到达" in text_input or "取件码" in text_input or "柜" in text_input):
|
385 |
+
pickup_code = None
|
386 |
+
code_pattern = r'取件码[是为:]?\s*(\d{4,6})'
|
387 |
+
code_match = re.search(code_pattern, text_input)
|
388 |
+
|
389 |
+
todo_content = "取快递"
|
390 |
+
if code_match:
|
391 |
+
pickup_code = code_match.group(1)
|
392 |
+
todo_content = f"取快递(取件码:{pickup_code})"
|
393 |
+
|
394 |
+
todo_item = {
|
395 |
+
message_id: {
|
396 |
+
"is_todo": True,
|
397 |
+
"end_time": (datetime.now(timezone.utc) + timedelta(days=2)).isoformat(),
|
398 |
+
"location": "线下:快递柜",
|
399 |
+
"todo_content": todo_content,
|
400 |
+
"urgency": "important"
|
401 |
+
}
|
402 |
+
}
|
403 |
+
|
404 |
+
end_time = todo_item[message_id]["end_time"].split("T")[0]
|
405 |
+
location = todo_item[message_id]["location"]
|
406 |
+
|
407 |
+
combined_content = f"{todo_content} (截止时间: {end_time}, 地点: {location})"
|
408 |
+
|
409 |
+
output_for_df = []
|
410 |
+
output_for_df.append([1, combined_content, "重要"])
|
411 |
+
|
412 |
+
return output_for_df
|
413 |
+
|
414 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
415 |
+
return [["error", "Filter API configuration incomplete", "-"]]
|
416 |
+
|
417 |
+
headers = {
|
418 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
419 |
+
"Accept": "application/json"
|
420 |
+
}
|
421 |
+
payload = {
|
422 |
+
"model": Filter_MODEL_NAME,
|
423 |
+
"messages": llm_messages,
|
424 |
+
"temperature": 0.2,
|
425 |
+
"top_p": 0.95,
|
426 |
+
"max_tokens": 1024,
|
427 |
+
"stream": False
|
428 |
+
}
|
429 |
+
|
430 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
431 |
+
|
432 |
+
try:
|
433 |
+
response = requests.post(api_url, headers=headers, json=payload)
|
434 |
+
response.raise_for_status()
|
435 |
+
raw_llm_response = response.json()['choices'][0]['message']['content']
|
436 |
+
except requests.exceptions.RequestException as e:
|
437 |
+
return [["error", f"Filter API call failed: {e}", "-"]]
|
438 |
+
|
439 |
+
parsed_todos_data = json_parser(raw_llm_response)
|
440 |
+
|
441 |
+
if "error" in parsed_todos_data:
|
442 |
+
return [["error", f"LLM response parsing error: {parsed_todos_data['error']}", parsed_todos_data.get('raw_text', '')[:50] + "..."]]
|
443 |
+
|
444 |
+
output_for_df = []
|
445 |
+
|
446 |
+
if isinstance(parsed_todos_data, dict):
|
447 |
+
todo_info = None
|
448 |
+
for key, value in parsed_todos_data.items():
|
449 |
+
if key == message_id or key == str(message_id):
|
450 |
+
todo_info = value
|
451 |
+
break
|
452 |
+
|
453 |
+
if todo_info and isinstance(todo_info, dict) and todo_info.get("is_todo", False):
|
454 |
+
todo_content = todo_info.get("todo_content", "未指定待办内容")
|
455 |
+
end_time = todo_info.get("end_time")
|
456 |
+
location = todo_info.get("location")
|
457 |
+
urgency = todo_info.get("urgency", "unimportant")
|
458 |
+
|
459 |
+
combined_content = todo_content
|
460 |
+
|
461 |
+
if end_time and end_time != "null":
|
462 |
+
try:
|
463 |
+
date_part = end_time.split("T")[0] if "T" in end_time else end_time
|
464 |
+
combined_content += f" (截止时间: {date_part}"
|
465 |
+
except:
|
466 |
+
combined_content += f" (截止时间: {end_time}"
|
467 |
+
else:
|
468 |
+
combined_content += " ("
|
469 |
+
|
470 |
+
if location and location != "null":
|
471 |
+
combined_content += f", 地点: {location})"
|
472 |
+
else:
|
473 |
+
combined_content += ")"
|
474 |
+
|
475 |
+
urgency_display = "一般"
|
476 |
+
if urgency == "urgent":
|
477 |
+
urgency_display = "紧急"
|
478 |
+
elif urgency == "important":
|
479 |
+
urgency_display = "重要"
|
480 |
+
|
481 |
+
output_for_df = []
|
482 |
+
output_for_df.append([1, combined_content, urgency_display])
|
483 |
+
else:
|
484 |
+
output_for_df = []
|
485 |
+
output_for_df.append([1, "此消息不包含待办事项", "-"])
|
486 |
+
|
487 |
+
elif isinstance(parsed_todos_data, list):
|
488 |
+
output_for_df = []
|
489 |
+
|
490 |
+
if not parsed_todos_data:
|
491 |
+
return [[1, "未能生成待办事项", "-"]]
|
492 |
+
|
493 |
+
for i, item in enumerate(parsed_todos_data):
|
494 |
+
if isinstance(item, dict):
|
495 |
+
todo_content = item.get('todo_content', item.get('content', 'N/A'))
|
496 |
+
status = item.get('status', '未完成')
|
497 |
+
urgency = item.get('urgency', 'normal')
|
498 |
+
|
499 |
+
combined_content = todo_content
|
500 |
+
|
501 |
+
if 'end_time' in item and item['end_time']:
|
502 |
+
try:
|
503 |
+
if isinstance(item['end_time'], str):
|
504 |
+
date_part = item['end_time'].split("T")[0] if "T" in item['end_time'] else item['end_time']
|
505 |
+
combined_content += f" (截止时间: {date_part}"
|
506 |
+
else:
|
507 |
+
combined_content += f" (截止时间: {str(item['end_time'])}"
|
508 |
+
except Exception:
|
509 |
+
combined_content += " ("
|
510 |
+
else:
|
511 |
+
combined_content += " ("
|
512 |
+
|
513 |
+
if 'location' in item and item['location']:
|
514 |
+
combined_content += f", 地点: {item['location']})"
|
515 |
+
else:
|
516 |
+
combined_content += ")"
|
517 |
+
|
518 |
+
importance = "一般"
|
519 |
+
if urgency == "urgent":
|
520 |
+
importance = "紧急"
|
521 |
+
elif urgency == "important":
|
522 |
+
importance = "重要"
|
523 |
+
|
524 |
+
output_for_df.append([i + 1, combined_content, importance])
|
525 |
+
else:
|
526 |
+
try:
|
527 |
+
item_str = str(item) if item is not None else "未知项目"
|
528 |
+
output_for_df.append([i + 1, item_str, "一般"])
|
529 |
+
except Exception:
|
530 |
+
output_for_df.append([i + 1, "处理错误的项目", "一般"])
|
531 |
+
|
532 |
+
if not output_for_df:
|
533 |
+
return [["info", "未发现待办事项", "-"]]
|
534 |
+
|
535 |
+
return output_for_df
|
536 |
+
|
537 |
+
except Exception as e:
|
538 |
+
return [["error", f"LLM call/parse error: {str(e)}", "-"]]
|
539 |
+
# 多模态数据从这里调用
|
540 |
+
def process(audio, image):
|
541 |
+
"""处理音频和图片输入,返回基本信息"""
|
542 |
+
if audio is not None:
|
543 |
+
sample_rate, audio_data = audio
|
544 |
+
audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
|
545 |
+
else:
|
546 |
+
audio_info = "未收到音频"
|
547 |
+
|
548 |
+
if image is not None:
|
549 |
+
image_info = f"图片尺寸: {image.shape}"
|
550 |
+
else:
|
551 |
+
image_info = "未收到图片"
|
552 |
+
|
553 |
+
return audio_info, image_info
|
554 |
+
|
555 |
+
def respond(message, history, system_message, max_tokens, temperature, top_p, audio, image):
|
556 |
+
"""处理聊天响应,支持流式输出"""
|
557 |
+
chat_messages = [{"role": "system", "content": system_message}]
|
558 |
+
for val in history:
|
559 |
+
if val[0]:
|
560 |
+
chat_messages.append({"role": "user", "content": val[0]})
|
561 |
+
if val[1]:
|
562 |
+
chat_messages.append({"role": "assistant", "content": val[1]})
|
563 |
+
chat_messages.append({"role": "user", "content": message})
|
564 |
+
|
565 |
+
chat_response_stream = ""
|
566 |
+
if not Filter_API_BASE_URL or not Filter_API_KEY or not Filter_MODEL_NAME:
|
567 |
+
yield "Filter API 配置不完整,无法提供聊天回复。", []
|
568 |
+
return
|
569 |
+
|
570 |
+
headers = {
|
571 |
+
"Authorization": f"Bearer {Filter_API_KEY}",
|
572 |
+
"Accept": "application/json"
|
573 |
+
}
|
574 |
+
payload = {
|
575 |
+
"model": Filter_MODEL_NAME,
|
576 |
+
"messages": chat_messages,
|
577 |
+
"temperature": temperature,
|
578 |
+
"top_p": top_p,
|
579 |
+
"max_tokens": max_tokens,
|
580 |
+
"stream": True
|
581 |
+
}
|
582 |
+
api_url = f"{Filter_API_BASE_URL}/chat/completions"
|
583 |
+
|
584 |
+
try:
|
585 |
+
response = requests.post(api_url, headers=headers, json=payload, stream=True)
|
586 |
+
response.raise_for_status()
|
587 |
+
|
588 |
+
for chunk in response.iter_content(chunk_size=None):
|
589 |
+
if chunk:
|
590 |
+
try:
|
591 |
+
for line in chunk.decode('utf-8').splitlines():
|
592 |
+
if line.startswith('data: '):
|
593 |
+
json_data = line[len('data: '):]
|
594 |
+
if json_data.strip() == '[DONE]':
|
595 |
+
break
|
596 |
+
data = json.loads(json_data)
|
597 |
+
token = data['choices'][0]['delta'].get('content', '')
|
598 |
+
if token:
|
599 |
+
chat_response_stream += token
|
600 |
+
yield chat_response_stream, []
|
601 |
+
except json.JSONDecodeError:
|
602 |
+
pass
|
603 |
+
except Exception as e:
|
604 |
+
yield chat_response_stream + f"\n\n错误: {e}", []
|
605 |
+
|
606 |
+
except requests.exceptions.RequestException as e:
|
607 |
+
yield f"调用 NVIDIA API 失败: {e}", []
|
608 |
+
# Gradio 界面:多模态上传入口
|
609 |
+
with gr.Blocks() as app:
|
610 |
+
gr.Markdown("# ToDoAgent Multi-Modal Interface with ToDo List")
|
611 |
+
|
612 |
+
with gr.Row():
|
613 |
+
with gr.Column(scale=2):
|
614 |
+
gr.Markdown("## Chat Interface")
|
615 |
+
chatbot = gr.Chatbot(height=450, label="聊天记录", type="messages")
|
616 |
+
msg = gr.Textbox(label="输入消息", placeholder="输入您的问题或待办事项...")
|
617 |
+
|
618 |
+
with gr.Row():
|
619 |
+
audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
|
620 |
+
image_input = gr.Image(label="上传图片", type="numpy")
|
621 |
+
|
622 |
+
with gr.Accordion("高级设置", open=False):
|
623 |
+
system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
|
624 |
+
max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="最大生成长度(聊天)")
|
625 |
+
temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="温度(聊天)")
|
626 |
+
top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p(聊天)")
|
627 |
+
|
628 |
+
with gr.Row():
|
629 |
+
submit_btn = gr.Button("发送", variant="primary")
|
630 |
+
clear_btn = gr.Button("清除聊天和ToDo")
|
631 |
+
|
632 |
+
with gr.Column(scale=1):
|
633 |
+
gr.Markdown("## Generated ToDo List")
|
634 |
+
todolist_df = gr.DataFrame(headers=["ID", "任务内容", "状态"],
|
635 |
+
datatype=["number", "str", "str"],
|
636 |
+
row_count=(0, "dynamic"),
|
637 |
+
col_count=(3, "fixed"),
|
638 |
+
label="待办事项列表")
|
639 |
+
|
640 |
+
def handle_submit(user_msg_content, ch_history, sys_msg, max_t, temp, t_p, audio_f, image_f):
|
641 |
+
"""处理用户提交的消息,生成聊天回复和待办事项"""
|
642 |
+
# 首先处理多模态输入,获取多模态内容
|
643 |
+
multimodal_text_content = ""
|
644 |
+
xunfei_config = get_hf_xunfei_config()
|
645 |
+
xunfei_appid = xunfei_config.get('appid')
|
646 |
+
xunfei_apikey = xunfei_config.get('apikey')
|
647 |
+
xunfei_apisecret = xunfei_config.get('apisecret')
|
648 |
+
|
649 |
+
# 添加调试日志
|
650 |
+
logger.info(f"开始多模态处理 - 音频: {audio_f is not None}, 图像: {image_f is not None}")
|
651 |
+
logger.info(f"讯飞配置状态 - appid: {bool(xunfei_appid)}, apikey: {bool(xunfei_apikey)}, apisecret: {bool(xunfei_apisecret)}")
|
652 |
+
|
653 |
+
# 处理音频输入(独立处理)
|
654 |
+
if audio_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
|
655 |
+
logger.info("开始处理音频输入...")
|
656 |
+
try:
|
657 |
+
import tempfile
|
658 |
+
import soundfile as sf
|
659 |
+
import os
|
660 |
+
|
661 |
+
audio_sample_rate, audio_data = audio_f
|
662 |
+
logger.info(f"音频信息: 采样率 {audio_sample_rate}Hz, 数据长度 {len(audio_data)}")
|
663 |
+
|
664 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
|
665 |
+
sf.write(temp_audio.name, audio_data, audio_sample_rate)
|
666 |
+
temp_audio_path = temp_audio.name
|
667 |
+
logger.info(f"音频临时文件已保存: {temp_audio_path}")
|
668 |
+
|
669 |
+
audio_text = audio_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_audio_path)
|
670 |
+
logger.info(f"音频识别结果: {audio_text}")
|
671 |
+
if audio_text:
|
672 |
+
multimodal_text_content += f"音频内容: {audio_text}"
|
673 |
+
|
674 |
+
os.unlink(temp_audio_path)
|
675 |
+
logger.info("音频处理完成")
|
676 |
+
except Exception as e:
|
677 |
+
logger.error(f"音频处理错误: {str(e)}")
|
678 |
+
elif audio_f is not None:
|
679 |
+
logger.warning("音频文件存在但讯飞配置不完整,跳过音频处理")
|
680 |
+
|
681 |
+
# 处理图像输入(独立处理)
|
682 |
+
if image_f is not None and xunfei_appid and xunfei_apikey and xunfei_apisecret:
|
683 |
+
logger.info("开始处理图像输入...")
|
684 |
+
try:
|
685 |
+
import tempfile
|
686 |
+
from PIL import Image
|
687 |
+
import os
|
688 |
+
|
689 |
+
logger.info(f"图像信息: 形状 {image_f.shape}, 数据类型 {image_f.dtype}")
|
690 |
+
|
691 |
+
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_image:
|
692 |
+
if len(image_f.shape) == 3: # RGB图像
|
693 |
+
pil_image = Image.fromarray(image_f.astype('uint8'), 'RGB')
|
694 |
+
else: # 灰度图像
|
695 |
+
pil_image = Image.fromarray(image_f.astype('uint8'), 'L')
|
696 |
+
|
697 |
+
pil_image.save(temp_image.name, 'JPEG')
|
698 |
+
temp_image_path = temp_image.name
|
699 |
+
logger.info(f"图像临时文件已保存: {temp_image_path}")
|
700 |
+
|
701 |
+
image_text = image_to_str(xunfei_appid, xunfei_apikey, xunfei_apisecret, temp_image_path)
|
702 |
+
logger.info(f"图像识别结果: {image_text}")
|
703 |
+
if image_text:
|
704 |
+
if multimodal_text_content: # 如果已有音频内容,添加分隔符
|
705 |
+
multimodal_text_content += "\n"
|
706 |
+
multimodal_text_content += f"图像内容: {image_text}"
|
707 |
+
|
708 |
+
os.unlink(temp_image_path)
|
709 |
+
logger.info("图像处理完成")
|
710 |
+
except Exception as e:
|
711 |
+
logger.error(f"图像处理错误: {str(e)}")
|
712 |
+
elif image_f is not None:
|
713 |
+
logger.warning("图像文件存在但讯飞配置不完整,跳过图像处理")
|
714 |
+
|
715 |
+
# 确定最终的用户输入内容:如果用户没有输入文本,使用多模态识别的内容
|
716 |
+
final_user_content = user_msg_content.strip() if user_msg_content else ""
|
717 |
+
if not final_user_content and multimodal_text_content:
|
718 |
+
final_user_content = multimodal_text_content
|
719 |
+
logger.info(f"用户无文本输入,使用多模态内容作为用户输入: {final_user_content}")
|
720 |
+
elif final_user_content and multimodal_text_content:
|
721 |
+
# 用户有文本输入,多模态内容作为补充
|
722 |
+
final_user_content = f"{final_user_content}\n{multimodal_text_content}"
|
723 |
+
logger.info(f"用户有文本输入,多模态内容作为补充")
|
724 |
+
|
725 |
+
# 如果最终还是没有任何内容,提供默认提示
|
726 |
+
if not final_user_content:
|
727 |
+
final_user_content = "[无输入内容]"
|
728 |
+
logger.warning("用户没有提供任何输入内容(文本、音频或图像)")
|
729 |
+
|
730 |
+
logger.info(f"最终用户输入内容: {final_user_content}")
|
731 |
+
|
732 |
+
# 1. 更新聊天记录 (用户部分) - 使用最终确定的用户内容
|
733 |
+
if not ch_history: ch_history = []
|
734 |
+
ch_history.append({"role": "user", "content": final_user_content})
|
735 |
+
yield ch_history, []
|
736 |
+
|
737 |
+
# 2. 流式生成机器人回复并更新聊天记录
|
738 |
+
formatted_hist_for_respond = []
|
739 |
+
temp_user_msg_for_hist = None
|
740 |
+
for item_hist in ch_history[:-1]:
|
741 |
+
if item_hist["role"] == "user":
|
742 |
+
temp_user_msg_for_hist = item_hist["content"]
|
743 |
+
elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is not None:
|
744 |
+
formatted_hist_for_respond.append((temp_user_msg_for_hist, item_hist["content"]))
|
745 |
+
temp_user_msg_for_hist = None
|
746 |
+
elif item_hist["role"] == "assistant" and temp_user_msg_for_hist is None:
|
747 |
+
formatted_hist_for_respond.append(("", item_hist["content"]))
|
748 |
+
|
749 |
+
ch_history.append({"role": "assistant", "content": ""})
|
750 |
+
|
751 |
+
full_bot_response = ""
|
752 |
+
# 使用最终确定的用户内容进行对话
|
753 |
+
for bot_response_token, _ in respond(final_user_content, formatted_hist_for_respond, sys_msg, max_t, temp, t_p, audio_f, image_f):
|
754 |
+
full_bot_response = bot_response_token
|
755 |
+
ch_history[-1]["content"] = full_bot_response
|
756 |
+
yield ch_history, []
|
757 |
+
|
758 |
+
# 3. 生成 ToDoList - 使用最终确定的用户内容
|
759 |
+
text_for_todo = final_user_content
|
760 |
+
|
761 |
+
# 添加日志:输出用于ToDo生成的内容
|
762 |
+
logger.info(f"用于ToDo生成的内容: {text_for_todo}")
|
763 |
+
current_todos_list = []
|
764 |
+
|
765 |
+
filtered_result = filter_message_with_llm(text_for_todo)
|
766 |
+
|
767 |
+
if isinstance(filtered_result, dict) and "error" in filtered_result:
|
768 |
+
current_todos_list = [["Error", filtered_result['error'], "Filter Failed"]]
|
769 |
+
elif isinstance(filtered_result, dict) and filtered_result.get("分类") == "其他":
|
770 |
+
current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
|
771 |
+
elif isinstance(filtered_result, list):
|
772 |
+
category = None
|
773 |
+
|
774 |
+
if not filtered_result:
|
775 |
+
if text_for_todo:
|
776 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
777 |
+
current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
|
778 |
+
yield ch_history, current_todos_list
|
779 |
+
return
|
780 |
+
|
781 |
+
valid_item = None
|
782 |
+
for item in filtered_result:
|
783 |
+
if isinstance(item, dict):
|
784 |
+
valid_item = item
|
785 |
+
if "分类" in item:
|
786 |
+
category = item["分类"]
|
787 |
+
break
|
788 |
+
|
789 |
+
if valid_item is None:
|
790 |
+
if text_for_todo:
|
791 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
792 |
+
current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
|
793 |
+
yield ch_history, current_todos_list
|
794 |
+
return
|
795 |
+
|
796 |
+
if category == "其他":
|
797 |
+
current_todos_list = [["Info", "消息被归类为 '其他',无需生成 ToDo。", "Filtered"]]
|
798 |
+
else:
|
799 |
+
if text_for_todo:
|
800 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
801 |
+
current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
|
802 |
+
else:
|
803 |
+
if text_for_todo:
|
804 |
+
msg_id_todo = f"hf_app_todo_{datetime.now().strftime('%Y%m%d%H%M%S%f')}"
|
805 |
+
current_todos_list = generate_todolist_from_text(text_for_todo, msg_id_todo)
|
806 |
+
|
807 |
+
yield ch_history, current_todos_list
|
808 |
+
|
809 |
+
submit_btn.click(
|
810 |
+
handle_submit,
|
811 |
+
[msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
|
812 |
+
[chatbot, todolist_df]
|
813 |
+
)
|
814 |
+
msg.submit(
|
815 |
+
handle_submit,
|
816 |
+
[msg, chatbot, system_msg, max_tokens_slider, temperature_slider, top_p_slider, audio_input, image_input],
|
817 |
+
[chatbot, todolist_df]
|
818 |
+
)
|
819 |
+
|
820 |
+
def clear_all():
|
821 |
+
"""清除所有聊天记录和待办事项"""
|
822 |
+
return None, None, ""
|
823 |
+
clear_btn.click(clear_all, None, [chatbot, todolist_df, msg], queue=False)
|
824 |
+
# 多模态标签页
|
825 |
+
with gr.Tab("Audio/Image Processing (Original)"):
|
826 |
+
gr.Markdown("## 处理音频和图片")
|
827 |
+
audio_processor = gr.Audio(label="上传音频", type="numpy")
|
828 |
+
image_processor = gr.Image(label="上传图片", type="numpy")
|
829 |
+
process_btn = gr.Button("处理", variant="primary")
|
830 |
+
audio_output = gr.Textbox(label="音频信息")
|
831 |
+
image_output = gr.Textbox(label="图片信息")
|
832 |
+
|
833 |
+
process_btn.click(
|
834 |
+
process,
|
835 |
+
inputs=[audio_processor, image_processor],
|
836 |
+
outputs=[audio_output, image_output]
|
837 |
+
)
|
838 |
+
|
839 |
+
if __name__ == "__main__":
|
840 |
+
app.launch(debug=False)
|
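app_pro.py's `respond` parses the OpenAI-compatible streaming format by hand: each HTTP chunk is split into lines, lines beginning with `data: ` carry a JSON delta, and the literal `[DONE]` marks the end of the stream. A stripped-down sketch of just that loop; the base URL, key, and model in the commented usage are placeholders:

```python
import json
import requests

def stream_chat(base_url, api_key, model, messages):
    """Yield the growing assistant text from an OpenAI-compatible /chat/completions stream."""
    resp = requests.post(
        f"{base_url}/chat/completions",
        headers={"Authorization": f"Bearer {api_key}", "Accept": "application/json"},
        json={"model": model, "messages": messages, "stream": True},
        stream=True,
    )
    resp.raise_for_status()
    text = ""
    for chunk in resp.iter_content(chunk_size=None):
        for line in chunk.decode("utf-8").splitlines():
            if not line.startswith("data: "):
                continue
            payload = line[len("data: "):]
            if payload.strip() == "[DONE]":
                return
            try:
                delta = json.loads(payload)["choices"][0]["delta"].get("content", "")
            except (json.JSONDecodeError, KeyError, IndexError):
                continue  # skip keep-alives and malformed lines
            if delta:
                text += delta
                yield text

# Placeholder usage; the base URL, key and model are not taken from the config in this repo:
# for partial in stream_chat("https://example.com/v1", "sk-...", "some-model",
#                            [{"role": "user", "content": "你好"}]):
#     print(partial)
```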
audio_127.0.0.1.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c4cca96c289e5acdfd9d8e926bb40674e170374878d57e4d3c3f5aca3039bec8
|
3 |
+
size 1830956
|
image_127.0.0.1.jpg
ADDED
![]() |
requirements.txt
CHANGED
@@ -1,4 +1,8 @@
|
|
1 |
-
gradio
|
2 |
-
requests
|
3 |
-
pathlib
|
4 |
-
python-dateutil
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
requests
|
3 |
+
pathlib
|
4 |
+
python-dateutil
|
5 |
+
Pillow
|
6 |
+
numpy
|
7 |
+
wave
|
8 |
+
azure-ai-inference
|
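Two of the entries above, `pathlib` and `wave`, are Python standard-library modules and do not need to be installed, while several imports used elsewhere in this commit (`yaml` from PyYAML, `scipy`, `soundfile`, `huggingface_hub`) are not listed. A small sanity check, assuming those modules are intended dependencies of the Space:

```python
import importlib

# Module names taken from the imports visible in app.py, app_pro.py and se_app.py.
candidates = ["gradio", "requests", "yaml", "numpy", "PIL",
              "scipy", "soundfile", "huggingface_hub"]

for name in candidates:
    try:
        importlib.import_module(name)
        print(f"OK       {name}")
    except ImportError as exc:
        print(f"MISSING  {name}: {exc}")
```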
se_app.py
ADDED
@@ -0,0 +1,232 @@
|
1 |
+
import gradio as gr
|
2 |
+
from huggingface_hub import InferenceClient
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from scipy.io.wavfile import write as write_wav
|
6 |
+
from PIL import Image
|
7 |
+
from tools import audio_to_str, image_to_str # 导入tools.py中的方法
|
8 |
+
|
9 |
+
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
10 |
+
|
11 |
+
# 指定保存文件的相对路径
|
12 |
+
SAVE_DIR = 'download' # 相对路径
|
13 |
+
os.makedirs(SAVE_DIR, exist_ok=True) # 确保目录存在
|
14 |
+
|
15 |
+
def get_client_ip(request: gr.Request, debug_mode=False):
|
16 |
+
"""获取客户端真实IP地址"""
|
17 |
+
if request:
|
18 |
+
# 从请求头中获取真实IP(考虑代理情况)
|
19 |
+
x_forwarded_for = request.headers.get("x-forwarded-for", "")
|
20 |
+
if x_forwarded_for:
|
21 |
+
client_ip = x_forwarded_for.split(",")[0]
|
22 |
+
else:
|
23 |
+
client_ip = request.client.host
|
24 |
+
if debug_mode:
|
25 |
+
print(f"Debug: Client IP detected as {client_ip}")
|
26 |
+
return client_ip
|
27 |
+
return "unknown"
|
28 |
+
|
29 |
+
def save_audio(audio, filename):
|
30 |
+
"""保存音频为.wav文件"""
|
31 |
+
sample_rate, audio_data = audio
|
32 |
+
write_wav(filename, sample_rate, audio_data)
|
33 |
+
|
34 |
+
def save_image(image, filename):
|
35 |
+
"""保存图片为.jpg文件"""
|
36 |
+
img = Image.fromarray(image.astype('uint8'))
|
37 |
+
img.save(filename)
|
38 |
+
|
39 |
+
def process(audio, image, text, request: gr.Request):
|
40 |
+
"""处理语音、图片和文本的示例函数"""
|
41 |
+
client_ip = get_client_ip(request, True)
|
42 |
+
print(f"Processing request from IP: {client_ip}")
|
43 |
+
|
44 |
+
audio_info = "未收到音频"
|
45 |
+
image_info = "未收到图片"
|
46 |
+
text_info = "未收到文本"
|
47 |
+
audio_filename = None
|
48 |
+
image_filename = None
|
49 |
+
audio_text = ""
|
50 |
+
image_text = ""
|
51 |
+
|
52 |
+
if audio is not None:
|
53 |
+
sample_rate, audio_data = audio
|
54 |
+
audio_info = f"音频采样率: {sample_rate}Hz, 数据长度: {len(audio_data)}"
|
55 |
+
# 保存音频为.wav文件
|
56 |
+
audio_filename = os.path.join(SAVE_DIR, f"audio_{client_ip}.wav")
|
57 |
+
save_audio(audio, audio_filename)
|
58 |
+
print(f"Audio saved as {audio_filename}")
|
59 |
+
# 调用tools.py中的audio_to_str方法处理音频
|
60 |
+
audio_text = audio_to_str("33c1b63d", "40bf7cd82e31ace30a9cfb76309a43a3", "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4", audio_filename)
|
61 |
+
if audio_text:
|
62 |
+
print(f"Audio text: {audio_text}")
|
63 |
+
else:
|
64 |
+
print("Audio processing failed")
|
65 |
+
|
66 |
+
if image is not None:
|
67 |
+
image_info = f"图片尺寸: {image.shape}"
|
68 |
+
# 保存图片为.jpg文件
|
69 |
+
image_filename = os.path.join(SAVE_DIR, f"image_{client_ip}.jpg")
|
70 |
+
save_image(image, image_filename)
|
71 |
+
print(f"Image saved as {image_filename}")
|
72 |
+
# 调用tools.py中的image_to_str方法处理图片
|
73 |
+
image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
|
74 |
+
if image_text:
|
75 |
+
print(f"Image text: {image_text}")
|
76 |
+
else:
|
77 |
+
print("Image processing failed")
|
78 |
+
|
79 |
+
if text:
|
80 |
+
text_info = f"接收到文本: {text}"
|
81 |
+
|
82 |
+
return audio_info, image_info, text_info, audio_text, image_text
|
83 |
+
|
84 |
+
# 创建自定义的聊天界面
|
85 |
+
with gr.Blocks() as app:
|
86 |
+
gr.Markdown("# ToDoAgent Multi-Modal Interface")
|
87 |
+
|
88 |
+
# 创建两个标签页
|
89 |
+
with gr.Tab("Chat"):
|
90 |
+
# 修复Chatbot类型警告
|
91 |
+
chatbot = gr.Chatbot(height=500, type="messages")
|
92 |
+
|
93 |
+
msg = gr.Textbox(label="输入消息", placeholder="输入您的问题...")
|
94 |
+
|
95 |
+
# 上传区域
|
96 |
+
with gr.Row():
|
97 |
+
audio_input = gr.Audio(label="上传语音", type="numpy", sources=["upload", "microphone"])
|
98 |
+
image_input = gr.Image(label="上传图片", type="numpy")
|
99 |
+
|
100 |
+
# 设置区域
|
101 |
+
with gr.Accordion("高级设置", open=False):
|
102 |
+
system_msg = gr.Textbox(value="You are a friendly Chatbot.", label="系统提示")
|
103 |
+
max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="最大生成长度")
|
104 |
+
temperature = gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="温度")
|
105 |
+
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p")
|
106 |
+
|
107 |
+
# 提交按钮
|
108 |
+
submit_btn = gr.Button("发送", variant="primary")
|
109 |
+
|
110 |
+
# 清除按钮
|
111 |
+
clear = gr.Button("清除聊天")
|
112 |
+
|
113 |
+
# 事件处理
|
114 |
+
def user(user_message, chat_history):
|
115 |
+
return "", chat_history + [{"role": "user", "content": user_message}]
|
116 |
+
#新增多模态处理--1
|
117 |
+
def respond(message, chat_history, system_message, max_tokens, temperature, top_p, audio=None, image=None, text=None, request=None):
|
118 |
+
"""生成响应的函数"""
|
119 |
+
# 处理多模态输入
|
120 |
+
multimodal_content = ""
|
121 |
+
if audio is not None:
|
122 |
+
try:
|
123 |
+
audio_filename = os.path.join(SAVE_DIR, "temp_audio.wav")
|
124 |
+
save_audio(audio, audio_filename)
|
125 |
+
audio_text = audio_to_str("33c1b63d", "40bf7cd82e31ace30a9cfb76309a43a3", "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4", audio_filename)
|
126 |
+
if audio_text:
|
127 |
+
multimodal_content += f"音频内容: {audio_text}\n"
|
128 |
+
except Exception as e:
|
129 |
+
print(f"Audio processing error: {e}")
|
130 |
+
|
131 |
+
if image is not None:
|
132 |
+
try:
|
133 |
+
image_filename = os.path.join(SAVE_DIR, "temp_image.jpg")
|
134 |
+
save_image(image, image_filename)
|
135 |
+
image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_filename)
|
136 |
+
if image_text:
|
137 |
+
multimodal_content += f"图片内容: {image_text}\n"
|
138 |
+
except Exception as e:
|
139 |
+
print(f"Image processing error: {e}")
|
140 |
+
|
141 |
+
# 组合最终消息
|
142 |
+
final_message = message
|
143 |
+
if multimodal_content:
|
144 |
+
final_message = f"{message}\n\n{multimodal_content}"
|
145 |
+
|
146 |
+
# 构建消息历史
|
147 |
+
messages = [{"role": "system", "content": system_message}]
|
148 |
+
for chat in chat_history:
|
149 |
+
if isinstance(chat, dict) and "role" in chat and "content" in chat:
|
150 |
+
messages.append(chat)
|
151 |
+
|
152 |
+
messages.append({"role": "user", "content": final_message})
|
153 |
+
|
154 |
+
# 调用HuggingFace API
|
155 |
+
try:
|
156 |
+
response = client.chat_completion(
|
157 |
+
messages,
|
158 |
+
max_tokens=max_tokens,
|
159 |
+
stream=True,
|
160 |
+
temperature=temperature,
|
161 |
+
top_p=top_p,
|
162 |
+
)
|
163 |
+
|
164 |
+
partial_message = ""
|
165 |
+
for token in response:
|
166 |
+
if token.choices[0].delta.content is not None:
|
167 |
+
partial_message += token.choices[0].delta.content
|
168 |
+
yield partial_message
|
169 |
+
except Exception as e:
|
170 |
+
yield f"抱歉,生成响应时出现错误: {str(e)}"
|
171 |
+
|
172 |
+
def bot(chat_history, system_message, max_tokens, temperature, top_p, audio, image, text):
|
173 |
+
# 检查chat_history是否为空
|
174 |
+
if not chat_history or len(chat_history) == 0:
|
175 |
+
return
|
176 |
+
|
177 |
+
# 获取最后一条用户消息
|
178 |
+
last_message = chat_history[-1]
|
179 |
+
if not last_message or not isinstance(last_message, dict) or "content" not in last_message:
|
180 |
+
return
|
181 |
+
|
182 |
+
user_message = last_message["content"]
|
183 |
+
|
184 |
+
# 生成响应
|
185 |
+
bot_response = ""
|
186 |
+
for response in respond(
|
187 |
+
user_message,
|
188 |
+
chat_history[:-1],
|
189 |
+
system_message,
|
190 |
+
max_tokens,
|
191 |
+
temperature,
|
192 |
+
top_p,
|
193 |
+
audio,
|
194 |
+
image,
|
195 |
+
text
|
196 |
+
):
|
197 |
+
bot_response = response
|
198 |
+
# 添加助手回复到聊天历史
|
199 |
+
updated_history = chat_history + [{"role": "assistant", "content": bot_response}]
|
200 |
+
yield updated_history
|
201 |
+
|
202 |
+
msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
203 |
+
bot, [chatbot, system_msg, max_tokens, temperature, top_p, audio_input, image_input, msg], chatbot
|
204 |
+
)
|
205 |
+
|
206 |
+
submit_btn.click(user, [msg, chatbot], [msg, chatbot], queue=False).then(
|
207 |
+
bot, [chatbot, system_msg, max_tokens, temperature, top_p, audio_input, image_input, msg], chatbot
|
208 |
+
)
|
209 |
+
|
210 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
211 |
+
|
212 |
+
with gr.Tab("Audio/Image Processing"):
|
213 |
+
gr.Markdown("## 处理音频和图片")
|
214 |
+
audio_processor = gr.Audio(label="上传音频", type="numpy")
|
215 |
+
image_processor = gr.Image(label="上传图片", type="numpy")
|
216 |
+
text_input = gr.Textbox(label="输入文本")
|
217 |
+
process_btn = gr.Button("处理", variant="primary")
|
218 |
+
audio_output = gr.Textbox(label="音频信息")
|
219 |
+
image_output = gr.Textbox(label="图片信息")
|
220 |
+
text_output = gr.Textbox(label="文本信息")
|
221 |
+
audio_text_output = gr.Textbox(label="音频转文字结果")
|
222 |
+
image_text_output = gr.Textbox(label="图片转文字结果")
|
223 |
+
|
224 |
+
# 修改后的处理函数调用
|
225 |
+
process_btn.click(
|
226 |
+
process,
|
227 |
+
inputs=[audio_processor, image_processor, text_input],
|
228 |
+
outputs=[audio_output, image_output, text_output, audio_text_output, image_text_output]
|
229 |
+
)
|
230 |
+
|
231 |
+
if __name__ == "__main__":
|
232 |
+
app.launch()
|
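se_app.py splits each submission into two chained Gradio events: `user` echoes the message into the history and clears the textbox, then `.then(bot, ...)` streams the assistant reply into the same Chatbot. A minimal sketch of that two-step chain, with simplified handler bodies and illustrative names:

```python
import gradio as gr

def user(user_message, history):
    # Step 1: append the user's message and clear the textbox.
    return "", (history or []) + [{"role": "user", "content": user_message}]

def bot(history):
    # Step 2: append the assistant reply; the real handler streams tokens here.
    reply = f"You said: {history[-1]['content']}"
    yield history + [{"role": "assistant", "content": reply}]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox()
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, chatbot, chatbot
    )

if __name__ == "__main__":
    demo.launch()
```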
temp_audio.wav
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8a873051a6c784789c314ab829772eac2446271337f54d58db48e921e81ab71e
|
3 |
+
size 710700
|
todogen_LLM_config.yaml
CHANGED
@@ -38,4 +38,14 @@ HF_CONFIG_PATH:
 openai_filter:
   base_url_filter: https://aihubmix.com/v1
   api_key_filter: sk-BSNyITzJBSSgfFdJ792b66C7789c479cA7Ec1e36FfB343A1
-  model_filter: gpt-4o-mini
+  model_filter: gpt-4o-mini
+
+xunfei:
+  appid: 33c1b63d
+  apikey: 40bf7cd82e31ace30a9cfb76309a43a3
+  apisecret: OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4
+
+azure_speech:
+  key: 45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ
+  region: eastus2
+  endpoint: https://eastus2.stt.speech.microsoft.com
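The new `xunfei` and `azure_speech` blocks are consumed by the speech and OCR helpers in tools.py below. A minimal sketch of loading them (the `load_config` helper is hypothetical; only the YAML keys shown above come from the file):

import yaml

def load_config(path="todogen_LLM_config.yaml"):
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

cfg = load_config()
xunfei_cfg = cfg["xunfei"]        # appid / apikey / apisecret
azure_cfg = cfg["azure_speech"]   # key / region / endpoint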
tools.py
ADDED
@@ -0,0 +1,828 @@
#!/usr/bin/python3
# -*- coding:utf-8 -*-
import os
import datetime
import re
import time
import traceback
import math
from urllib.parse import urlparse
from urllib3 import encode_multipart_formdata
from wsgiref.handlers import format_date_time
from time import mktime
import hashlib
import base64
import hmac
from urllib.parse import urlencode
import json
import requests
import azure.cognitiveservices.speech as speechsdk

# Constants
LFASR_HOST = "http://upload-ost-api.xfyun.cn/file"  # file-upload host
API_INIT = "/mpupload/init"  # init endpoint
API_UPLOAD = "/upload"  # simple-upload endpoint
API_CUT = "/mpupload/upload"  # chunk-upload endpoint
API_CUT_COMPLETE = "/mpupload/complete"  # chunk-complete endpoint
API_CUT_CANCEL = "/mpupload/cancel"  # chunk-cancel endpoint
FILE_PIECE_SIZE = 5242880  # chunk size: 5 MB
PRO_CREATE_URI = "/v2/ost/pro_create"
QUERY_URI = "/v2/ost/query"


# File upload class
class FileUploader:
    def __init__(self, app_id, api_key, api_secret, upload_file_path):
        self.app_id = app_id
        self.api_key = api_key
        self.api_secret = api_secret
        self.upload_file_path = upload_file_path

    def get_request_id(self):
        """Generate a request ID."""
        return time.strftime("%Y%m%d%H%M")

    def hashlib_256(self, data):
        """Compute a SHA-256 digest and base64-encode it."""
        m = hashlib.sha256(bytes(data.encode(encoding="utf-8"))).digest()
        digest = "SHA-256=" + base64.b64encode(m).decode(encoding="utf-8")
        return digest

    def assemble_auth_header(self, request_url, file_data_type, method="", body=""):
        """Assemble the authentication headers."""
        u = urlparse(request_url)
        host = u.hostname
        path = u.path
        now = datetime.datetime.now()
        date = format_date_time(mktime(now.timetuple()))
        # note: hashlib_256 already prefixes "SHA-256=", so this header value ends up double-prefixed
        digest = "SHA256=" + self.hashlib_256("")
        signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1\ndigest: {}".format(
            host, date, method, path, digest
        )
        signature_sha = hmac.new(
            self.api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
        authorization = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
            self.api_key,
            "hmac-sha256",
            "host date request-line digest",
            signature_sha,
        )
        headers = {
            "host": host,
            "date": date,
            "authorization": authorization,
            "digest": digest,
            "content-type": file_data_type,
        }
        return headers

    def call_api(self, url, file_data, file_data_type):
        """POST to an API endpoint."""
        headers = self.assemble_auth_header(
            url, file_data_type, method="POST", body=file_data
        )
        try:
            resp = requests.post(url, headers=headers, data=file_data, timeout=8)
            print("上传状态:", resp.status_code, resp.text)
            return resp.json()
        except Exception as e:
            print("上传失败!Exception :%s" % e)
            return None

    def upload_cut_complete(self, upload_id):
        """Finish a chunked upload."""
        body_dict = {
            "app_id": self.app_id,
            "request_id": self.get_request_id(),
            "upload_id": upload_id,
        }
        file_data_type = "application/json"
        url = LFASR_HOST + API_CUT_COMPLETE
        response = self.call_api(url, json.dumps(body_dict), file_data_type)
        if response and "data" in response and "url" in response["data"]:
            file_url = response["data"]["url"]
            print("任务上传结束")
            return file_url
        else:
            print("分片上传完成失败", response)
            return None

    def upload_file(self):
        """Upload the file, choosing chunked or simple upload based on its size."""
        file_total_size = os.path.getsize(self.upload_file_path)
        if file_total_size < 31457280:  # 30 MB
            print("-----不使用分块上传-----")
            return self.simple_upload()
        else:
            print("-----使用分块上传-----")
            return self.multipart_upload()

    def simple_upload(self):
        """Simple single-request upload."""
        try:
            with open(self.upload_file_path, mode="rb") as f:
                file = {
                    "data": (self.upload_file_path, f.read()),
                    "app_id": self.app_id,
                    "request_id": self.get_request_id(),
                }
                encode_data = encode_multipart_formdata(file)
                file_data = encode_data[0]
                file_data_type = encode_data[1]
                url = LFASR_HOST + API_UPLOAD
                response = self.call_api(url, file_data, file_data_type)
                if response and "data" in response and "url" in response["data"]:
                    return response["data"]["url"]
                else:
                    print("简单上传失败", response)
                    return None
        except FileNotFoundError:
            print("文件未找到:", self.upload_file_path)
            return None

    def multipart_upload(self):
        """Chunked (multipart) upload."""
        upload_id = self.prepare_upload()
        if not upload_id:
            return None

        if not self.do_upload(upload_id):
            return None

        file_url = self.upload_cut_complete(upload_id)
        print("分片上传地址:", file_url)
        return file_url

    def prepare_upload(self):
        """Pre-flight request that obtains an upload_id."""
        body_dict = {
            "app_id": self.app_id,
            "request_id": self.get_request_id(),
        }
        url = LFASR_HOST + API_INIT
        file_data_type = "application/json"
        response = self.call_api(url, json.dumps(body_dict), file_data_type)
        if response and "data" in response and "upload_id" in response["data"]:
            return response["data"]["upload_id"]
        else:
            print("预处理失败", response)
            return None

    def do_upload(self, upload_id):
        """Upload the file chunk by chunk."""
        file_total_size = os.path.getsize(self.upload_file_path)
        chunk_size = FILE_PIECE_SIZE
        chunks = math.ceil(file_total_size / chunk_size)
        request_id = self.get_request_id()
        slice_id = 1

        print(
            "文件:",
            self.upload_file_path,
            " 文件大小:",
            file_total_size,
            " 分块大小:",
            chunk_size,
            " 分块数:",
            chunks,
        )

        with open(self.upload_file_path, mode="rb") as content:
            while slice_id <= chunks:
                current_size = min(
                    chunk_size, file_total_size - (slice_id - 1) * chunk_size
                )

                file = {
                    "data": (self.upload_file_path, content.read(current_size)),
                    "app_id": self.app_id,
                    "request_id": request_id,
                    "upload_id": upload_id,
                    "slice_id": slice_id,
                }

                encode_data = encode_multipart_formdata(file)
                file_data = encode_data[0]
                file_data_type = encode_data[1]
                url = LFASR_HOST + API_CUT

                resp = self.call_api(url, file_data, file_data_type)
                count = 0
                while not resp and (count < 3):
                    print("上传重试")
                    resp = self.call_api(url, file_data, file_data_type)
                    count = count + 1
                    time.sleep(1)
                if not resp:
                    print("分片上传失败")
                    return False
                slice_id += 1

        return True


class ResultExtractor:
    def __init__(self, appid, apikey, apisecret):
        # POST request parameters
        self.Host = "ost-api.xfyun.cn"
        self.RequestUriCreate = PRO_CREATE_URI
        self.RequestUriQuery = QUERY_URI
        # Build the URLs
        if re.match(r"^\d", self.Host):
            self.urlCreate = "http://" + self.Host + self.RequestUriCreate
            self.urlQuery = "http://" + self.Host + self.RequestUriQuery
        else:
            self.urlCreate = "https://" + self.Host + self.RequestUriCreate
            self.urlQuery = "https://" + self.Host + self.RequestUriQuery
        self.HttpMethod = "POST"
        self.APPID = appid
        self.Algorithm = "hmac-sha256"
        self.HttpProto = "HTTP/1.1"
        self.UserName = apikey
        self.Secret = apisecret

        # Current time (UTC)
        cur_time_utc = datetime.datetime.now(datetime.timezone.utc)
        self.Date = self.httpdate(cur_time_utc)

        # Parameters for the transcription request
        self.BusinessArgsCreate = {
            "language": "zh_cn",
            "accent": "mandarin",
            "domain": "pro_ost_ed",
        }

    def img_read(self, path):
        with open(path, "rb") as fo:
            return fo.read()

    def hashlib_256(self, res):
        m = hashlib.sha256(bytes(res.encode(encoding="utf-8"))).digest()
        result = "SHA-256=" + base64.b64encode(m).decode(encoding="utf-8")
        return result

    def httpdate(self, dt):
        weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"][dt.weekday()]
        month = [
            "Jan",
            "Feb",
            "Mar",
            "Apr",
            "May",
            "Jun",
            "Jul",
            "Aug",
            "Sep",
            "Oct",
            "Nov",
            "Dec",
        ][dt.month - 1]
        return "%s, %02d %s %04d %02d:%02d:%02d GMT" % (
            weekday,
            dt.day,
            month,
            dt.year,
            dt.hour,
            dt.minute,
            dt.second,
        )

    def generateSignature(self, digest, uri):
        signature_str = "host: " + self.Host + "\n"
        signature_str += "date: " + self.Date + "\n"
        signature_str += self.HttpMethod + " " + uri + " " + self.HttpProto + "\n"
        signature_str += "digest: " + digest
        signature = hmac.new(
            bytes(self.Secret.encode("utf-8")),
            bytes(signature_str.encode("utf-8")),
            digestmod=hashlib.sha256,
        ).digest()
        result = base64.b64encode(signature)
        return result.decode(encoding="utf-8")

    def init_header(self, data, uri):
        digest = self.hashlib_256(data)
        sign = self.generateSignature(digest, uri)
        auth_header = (
            'api_key="%s",algorithm="%s", '
            'headers="host date request-line digest", '
            'signature="%s"' % (self.UserName, self.Algorithm, sign)
        )
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "Method": "POST",
            "Host": self.Host,
            "Date": self.Date,
            "Digest": digest,
            "Authorization": auth_header,
        }
        return headers

    def get_create_body(self, fileurl):
        post_data = {
            "common": {"app_id": self.APPID},
            "business": self.BusinessArgsCreate,
            "data": {"audio_src": "http", "audio_url": fileurl, "encoding": "raw"},
        }
        body = json.dumps(post_data)
        return body

    def get_query_body(self, task_id):
        post_data = {
            "common": {"app_id": self.APPID},
            "business": {
                "task_id": task_id,
            },
        }
        body = json.dumps(post_data)
        return body

    def call(self, url, body, headers):
        try:
            response = requests.post(url, data=body, headers=headers, timeout=8)
            status_code = response.status_code
            if status_code != 200:
                info = response.content
                return info
            else:
                try:
                    return json.loads(response.text)
                except json.JSONDecodeError:
                    return response.text
        except Exception as e:
            print("Exception :%s" % e)
            return None

    def task_create(self, fileurl):
        body = self.get_create_body(fileurl)
        headers_create = self.init_header(body, self.RequestUriCreate)
        return self.call(self.urlCreate, body, headers_create)

    def task_query(self, task_id):
        query_body = self.get_query_body(task_id)
        headers_query = self.init_header(query_body, self.RequestUriQuery)
        return self.call(self.urlQuery, query_body, headers_query)

    def extract_text(self, result):
        """
        Extract the text content from an API response.
        Supports several result formats, with extra error handling.
        """
        # Debug: print the raw result type
        print(f"\n[DEBUG] extract_text 输入类型: {type(result)}")

        # If the result is a string, try to parse it as JSON
        if isinstance(result, str):
            print(f"[DEBUG] 字符串内容 (前200字符): {result[:200]}")
            try:
                result = json.loads(result)
                print("[DEBUG] 成功解析字符串为JSON对象")
            except json.JSONDecodeError:
                print("[DEBUG] 无法解析为JSON,返回原始字符串")
                return result

        # Handle dict-typed results
        if isinstance(result, dict):
            print("[DEBUG] 处理字典类型结果")

            # 1. Check for an error code
            if "code" in result and result["code"] != 0:
                error_msg = result.get("message", "未知错误")
                print(
                    f"[ERROR] API返回错误: code={result['code']}, message={error_msg}"
                )
                return f"错误: {error_msg}"

            # 2. The result text is returned directly
            if "result" in result and isinstance(result["result"], str):
                print("[DEBUG] 找到直接结果字段")
                return result["result"]

            # 3. Lattice structure (detailed result)
            if "lattice" in result and isinstance(result["lattice"], list):
                print("[DEBUG] 解析lattice结构")
                text_parts = []
                for lattice in result["lattice"]:
                    if not isinstance(lattice, dict):
                        continue

                    # Get the json_1best content
                    json_1best = lattice.get("json_1best", {})
                    if not json_1best or not isinstance(json_1best, dict):
                        continue

                    # Handle the st field - note: st may be a dict or a list
                    st_content = json_1best.get("st")
                    st_list = []
                    if isinstance(st_content, dict):
                        st_list = [st_content]  # wrap in a list for uniform handling
                    elif isinstance(st_content, list):
                        st_list = st_content

                    for st in st_list:
                        if isinstance(st, str):
                            # Already a plain string result
                            text_parts.append(st)
                        elif isinstance(st, dict):
                            # Walk the dict-shaped st
                            rt = st.get("rt", [])
                            if not isinstance(rt, list):
                                continue

                            for item in rt:
                                if isinstance(item, dict):
                                    ws_list = item.get("ws", [])
                                    if isinstance(ws_list, list):
                                        for ws in ws_list:
                                            if isinstance(ws, dict):
                                                cw_list = ws.get("cw", [])
                                                if isinstance(cw_list, list):
                                                    for cw in cw_list:
                                                        if isinstance(cw, dict):
                                                            w = cw.get("w", "")
                                                            if w:
                                                                text_parts.append(w)
                return "".join(text_parts)

            # 4. Simplified structure (st at the top level)
            if "st" in result and isinstance(result["st"], list):
                print("[DEBUG] 解析st结构")
                text_parts = []
                for st in result["st"]:
                    if isinstance(st, str):
                        text_parts.append(st)
                    elif isinstance(st, dict):
                        rt = st.get("rt", [])
                        if isinstance(rt, list):
                            for item in rt:
                                if isinstance(item, dict):
                                    ws_list = item.get("ws", [])
                                    if isinstance(ws_list, list):
                                        for ws in ws_list:
                                            if isinstance(ws, dict):
                                                cw_list = ws.get("cw", [])
                                                if isinstance(cw_list, list):
                                                    for cw in cw_list:
                                                        if isinstance(cw, dict):
                                                            w = cw.get("w", "")
                                                            if w:
                                                                text_parts.append(w)
                return "".join(text_parts)

            # 5. Unknown structure
            print("[WARNING] 无法识别的结果结构")
            return json.dumps(result, indent=2, ensure_ascii=False)

        # 6. Non-dict result
        print(f"[WARNING] 非字典类型结果: {type(result)}")
        return str(result)


def audio_to_str(appid, apikey, apisecret, file_path):
    """
    Call the iFLYTEK (xfyun) open-platform API and return the transcription of an audio file.

    Args:
        appid (str): iFLYTEK open-platform appid.
        apikey (str): iFLYTEK open-platform apikey.
        apisecret (str): iFLYTEK open-platform apisecret.
        file_path (str): Path to the audio file.

    Returns:
        str: The transcription text, or None if an error occurred.
    """
    # Check that the file exists
    if not os.path.exists(file_path):
        print(f"错误:文件 {file_path} 不存在")
        return None

    try:
        # 1. Upload the file
        file_uploader = FileUploader(
            app_id=appid,
            api_key=apikey,
            api_secret=apisecret,
            upload_file_path=file_path,
        )
        fileurl = file_uploader.upload_file()
        if not fileurl:
            print("文件上传失败")
            return None
        print("文件上传成功,fileurl:", fileurl)

        # 2. Create the task and poll for the result
        result_extractor = ResultExtractor(appid, apikey, apisecret)
        print("\n------ 创建任务 -------")
        create_response = result_extractor.task_create(fileurl)

        # Debug: dump the create response
        print(
            f"[DEBUG] 创建任务响应: {json.dumps(create_response, indent=2, ensure_ascii=False)}"
        )

        if not isinstance(create_response, dict) or "data" not in create_response:
            print("创建任务失败:", create_response)
            return None

        task_id = create_response["data"]["task_id"]
        print(f"任务ID: {task_id}")

        # Poll the task
        print("\n------ 查询任务 -------")
        print("任务转写中······")
        max_attempts = 30
        attempt = 0

        while attempt < max_attempts:
            result = result_extractor.task_query(task_id)

            # Debug: dump the query response
            print(f"\n[QUERY {attempt + 1}] 响应类型: {type(result)}")
            if isinstance(result, dict):
                print(
                    f"[QUERY {attempt + 1}] 响应内容: {json.dumps(result, indent=2, ensure_ascii=False)}"
                )
            else:
                print(
                    f"[QUERY {attempt + 1}] 响应内容 (前200字符): {str(result)[:200]}"
                )

            # Validate the response
            if not isinstance(result, dict):
                print(f"无效响应类型: {type(result)}")
                return None

            # Check the API error code
            if "code" in result and result["code"] != 0:
                error_msg = result.get("message", "未知错误")
                print(f"API错误: code={result['code']}, message={error_msg}")
                return None

            # Read the task status
            task_data = result.get("data", {})
            task_status = task_data.get("task_status")

            if not task_status:
                print("响应中缺少任务状态字段")
                print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
                return None

            # Handle the different states
            if task_status in ["3", "4"]:  # finished, or finished with callback
                print("转写完成···")

                # Extract the result text
                result_content = task_data.get("result")
                if result_content is not None:
                    try:
                        result_text = result_extractor.extract_text(result_content)
                        print("\n转写结果:\n", result_text)
                        return result_text
                    except Exception as e:
                        print(f"\n提取文本时出错: {str(e)}")
                        print(f"错误详情:\n{traceback.format_exc()}")
                        print(
                            "原始结果内容:",
                            json.dumps(result_content, indent=2, ensure_ascii=False),
                        )
                        return None
                else:
                    print("\n响应中缺少结果字段")
                    print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
                    return None

            elif task_status in ["1", "2"]:  # pending or in progress
                print(
                    f"任务状态:{task_status},等待中... (尝试 {attempt + 1}/{max_attempts})"
                )
                time.sleep(5)
                attempt += 1
            else:
                print(f"未知任务状态:{task_status}")
                print("完整响应:", json.dumps(result, indent=2, ensure_ascii=False))
                return None
        else:
            print(f"超过最大查询次数({max_attempts}),任务可能仍在处理中")
            return None

    except Exception as e:
        print(f"发生异常: {str(e)}")
        print(f"错误详情:\n{traceback.format_exc()}")
        return None


"""
1. General OCR: the base64-encoded image data must not exceed 10 MB.
2. Obtain the appid, apiSecret and apiKey from the iFLYTEK open-platform console and fill them into this demo.
3. Supports Chinese and English, handwritten and printed text.
4. Improved results on tilted text, and supports some rare characters.
"""

# OCR endpoint
URL = "https://api.xf-yun.com/v1/private/sf8e6aca1"


class AssembleHeaderException(Exception):
    def __init__(self, msg):
        self.message = msg


class Url:
    def __init__(self, host, path, schema):
        self.host = host
        self.path = path
        self.schema = schema
        pass


# calculate sha256 and encode to base64
def sha256base64(data):
    sha256 = hashlib.sha256()
    sha256.update(data)
    digest = base64.b64encode(sha256.digest()).decode(encoding="utf-8")
    return digest


def parse_url(requset_url):
    stidx = requset_url.index("://")
    host = requset_url[stidx + 3 :]
    schema = requset_url[: stidx + 3]
    edidx = host.index("/")
    if edidx <= 0:
        raise AssembleHeaderException("invalid request url:" + requset_url)
    path = host[edidx:]
    host = host[:edidx]
    u = Url(host, path, schema)
    return u


# build websocket auth request url
def assemble_ws_auth_url(requset_url, method="POST", api_key="", api_secret=""):
    u = parse_url(requset_url)
    host = u.host
    path = u.path
    now = datetime.datetime.now()
    date = format_date_time(mktime(now.timetuple()))
    # print(date)  # optional: print the Date value

    signature_origin = "host: {}\ndate: {}\n{} {} HTTP/1.1".format(
        host, date, method, path
    )
    # print(signature_origin)  # optional: print the string being signed
    signature_sha = hmac.new(
        api_secret.encode("utf-8"),
        signature_origin.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).digest()
    signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
    authorization_origin = (
        'api_key="%s", algorithm="%s", headers="%s", signature="%s"'
        % (api_key, "hmac-sha256", "host date request-line", signature_sha)
    )
    authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
        encoding="utf-8"
    )
    # print(authorization_origin)  # optional: print the raw authorization string
    values = {"host": host, "date": date, "authorization": authorization}

    return requset_url + "?" + urlencode(values)


def image_to_str(endpoint=None, key=None, unused_param=None, file_path=None):
    """
    Call the Azure Computer Vision API to recognize text in an image.

    Args:
        endpoint (str): Azure Computer Vision endpoint URL.
        key (str): Azure Computer Vision API key.
        unused_param (str): Unused parameter, kept for compatibility.
        file_path (str): Path to the image file.

    Returns:
        str: The recognized text, or None if an error occurred.
    """

    # Default configuration
    if endpoint is None:
        endpoint = "https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/"
    if key is None:
        key = "45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ"

    try:
        # Read the image file
        with open(file_path, "rb") as f:
            image_data = f.read()

        # Build the request URL
        analyze_url = endpoint.rstrip('/') + "/vision/v3.2/read/analyze"

        # Request headers
        headers = {
            'Ocp-Apim-Subscription-Key': key,
            'Content-Type': 'application/octet-stream'
        }

        # Start the analysis with a POST request
        response = requests.post(analyze_url, headers=headers, data=image_data)

        if response.status_code != 202:
            print(f"分析请求失败: {response.status_code}, {response.text}")
            return None

        # Get the operation location
        operation_url = response.headers["Operation-Location"]

        # Poll for the result
        import time
        while True:
            result_response = requests.get(operation_url, headers={'Ocp-Apim-Subscription-Key': key})
            result = result_response.json()

            if result["status"] == "succeeded":
                # Collect the recognized lines
                text_results = []
                if "analyzeResult" in result and "readResults" in result["analyzeResult"]:
                    for read_result in result["analyzeResult"]["readResults"]:
                        for line in read_result["lines"]:
                            text_results.append(line["text"])

                return " ".join(text_results) if text_results else ""

            elif result["status"] == "failed":
                print(f"文字识别失败: {result}")
                return None

            # Wait one second and retry
            time.sleep(1)

    except Exception as e:
        print(f"发生异常: {e}")
        return None


if __name__ == "__main__":
    # iFLYTEK open-platform appid, secret, key and the input file paths
    appid = "33c1b63d"
    apikey = "40bf7cd82e31ace30a9cfb76309a43a3"
    apisecret = "OTY1YzIyZWM3YTg0OWZiMGE2ZjA2ZmE4"
    audio_path = r"audio_sample_little.wav"  # make sure the file path is correct
    image_path = r"1.png"  # make sure the file path is correct

    # Audio to text
    audio_text = audio_to_str(appid, apikey, apisecret, audio_path)
    # Image to text
    image_text = image_to_str(endpoint="https://ai-siyuwang5414995ai361208251338.cognitiveservices.azure.com/", key="45PYY2Av9CdMCveAjVG43MGKrnHzSxdiFTK9mWBgrOsMAHavxKj0JQQJ99BDACHYHv6XJ3w3AAAAACOGeVpQ", unused_param=None, file_path=image_path)

    print("-" * 20)

    print("\n音频转文字结果:", audio_text)
    print("\n图片转文字结果:", image_text)


def azure_speech_to_text(speech_key, speech_region, audio_file_path):
    """
    Convert an audio file to text using the Azure Speech service.

    Args:
        speech_key (str): Azure Speech API key.
        speech_region (str): Azure Speech service region.
        audio_file_path (str): Path to the audio file.

    Returns:
        str: The recognized text, or None if an error occurred.
    """
    try:
        # Speech configuration
        speech_config = speechsdk.SpeechConfig(subscription=speech_key, region=speech_region)
        speech_config.speech_recognition_language = "zh-CN"  # recognize Chinese

        # Audio configuration
        audio_config = speechsdk.audio.AudioConfig(filename=audio_file_path)

        # Create the recognizer
        speech_recognizer = speechsdk.SpeechRecognizer(speech_config=speech_config, audio_config=audio_config)

        # Run recognition once
        result = speech_recognizer.recognize_once()

        # Inspect the result
        if result.reason == speechsdk.ResultReason.RecognizedSpeech:
            print(f"Azure Speech识别成功: {result.text}")
            return result.text
        elif result.reason == speechsdk.ResultReason.NoMatch:
            print("Azure Speech未识别到语音")
            return None
        elif result.reason == speechsdk.ResultReason.Canceled:
            cancellation_details = result.cancellation_details
            print(f"Azure Speech识别被取消: {cancellation_details.reason}")
            if cancellation_details.reason == speechsdk.CancellationReason.Error:
                print(f"错误详情: {cancellation_details.error_details}")
            return None
    except Exception as e:
        print(f"Azure Speech识别出错: {str(e)}")
        return None
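For reference, a hedged end-to-end sketch tying tools.py to the config keys added above (the audio path is a placeholder, and this assumes todogen_LLM_config.yaml sits in the working directory):

import yaml
from tools import audio_to_str, azure_speech_to_text

with open("todogen_LLM_config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# iFLYTEK path: upload the file, create a transcription task, poll, extract text
xf = cfg["xunfei"]
xf_text = audio_to_str(xf["appid"], xf["apikey"], xf["apisecret"], "temp_audio.wav")

# Azure Speech path: single-shot recognition of the same file
az = cfg["azure_speech"]
az_text = azure_speech_to_text(az["key"], az["region"], "temp_audio.wav")

print(xf_text, az_text)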