Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
@@ -7,52 +7,7 @@ import json
|
|
7 |
import chardet
|
8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
9 |
import time
|
10 |
-
from functools import lru_cache
|
11 |
-
from sqlalchemy import create_engine
|
12 |
-
from sqlalchemy.orm import sessionmaker
|
13 |
-
from contextlib import contextmanager
|
14 |
-
import logging
|
15 |
-
import networkx as nx
|
16 |
-
from pyvis.network import Network
|
17 |
-
import pandas as pd
|
18 |
-
import matplotlib.pyplot as plt
|
19 |
-
from gqlalchemy import Memgraph
|
20 |
-
from mcp_use import RelationPredictor, insert_to_memgraph, get_memgraph_conn # 引入mcp_use中的功能
|
21 |
-
from relation_extraction.hparams import hparams # 引入模型超参数
|
22 |
-
|
23 |
-
# ======================== 数据库模块 ========================
|
24 |
-
MEMGRAPH_HOST = '18.159.132.161'
|
25 |
-
MEMGRAPH_PORT = 7687
|
26 |
-
MEMGRAPH_USERNAME = '[email protected]'
|
27 |
-
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
28 |
-
|
29 |
-
# 初始化 Memgraph 连接
|
30 |
-
memgraph = get_memgraph_conn()
|
31 |
-
|
32 |
-
# 初始化关系抽取模型
|
33 |
-
relation_predictor = RelationPredictor(hparams)
|
34 |
-
|
35 |
-
# ======================== 关系抽取功能整合 ========================
|
36 |
-
def extract_and_save_relations(text, entity1, entity2):
    """Predict the relation between two entities in *text* and store it in Memgraph.

    The prediction is delegated to the module-level ``relation_predictor``
    (a ``RelationPredictor`` imported from ``mcp_use``); a successful
    prediction is written through ``insert_to_memgraph`` using the
    module-level ``memgraph`` connection.

    Args:
        text: Sentence or passage to analyse.
        entity1: First entity to look for in ``text``.
        entity2: Second entity to look for in ``text``.

    Returns:
        A human-readable status string: a success message containing the
        stored triple, a "not found" message when the predictor returns
        ``None``, or a failure message when any step raises.
    """
    try:
        prediction = relation_predictor.predict_one(text, entity1, entity2)
        if prediction is not None:
            # The predictor hands back its own (head, relation, tail) triple;
            # that triple, not the raw arguments, is what gets persisted.
            head, relation, tail = prediction
            insert_to_memgraph(memgraph, head, relation, tail)
            return f"✅ 已写入 Memgraph:({head})-[:{relation}]->({tail})"
        return f"❌ 未找到实体 '{entity1}' 或 '{entity2}'"
    except Exception as e:
        # UI boundary: report the failure as text instead of propagating it.
        logging.error(f"关系抽取失败: {e}")
        return f"❌ 关系抽取失败: {e}"
|
55 |
-
|
56 |
# ======================== 数据库模块 ========================
|
57 |
from sqlalchemy import create_engine
|
58 |
from sqlalchemy.orm import sessionmaker
|
@@ -63,46 +18,6 @@ from pyvis.network import Network
|
|
63 |
import pandas as pd
|
64 |
import matplotlib.pyplot as plt
|
65 |
|
66 |
-
|
67 |
-
from gqlalchemy import Memgraph
|
68 |
-
import os
|
69 |
-
MEMGRAPH_HOST = '18.159.132.161'
|
70 |
-
MEMGRAPH_PORT = 7687
|
71 |
-
MEMGRAPH_USERNAME = '[email protected]'
|
72 |
-
MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
|
73 |
-
|
74 |
-
def hello_memgraph():
    """Test the Memgraph database connection and run a health check.

    Connection settings are read from the ``MEMGRAPH_HOST``,
    ``MEMGRAPH_PORT``, ``MEMGRAPH_USERNAME`` and ``MEMGRAPH_PASSWORD``
    environment variables. On success the function also creates a
    ``ConnectionTest`` node stamped with the current time, so every
    successful call leaves one extra node in the database.

    Returns:
        A status string: "✅ ..." with the reported health status on
        success, or "❌ 连接失败: ..." with the error text when anything
        in the connect / query sequence raises.
    """
    # Must exist before the try block: if the constructor (or one of the
    # os.environ lookups feeding it) raises, the finally clause would
    # otherwise hit an unbound name and replace the error message with an
    # UnboundLocalError.
    connection = None
    try:
        connection = Memgraph(
            host=os.environ["MEMGRAPH_HOST"],
            port=int(os.environ["MEMGRAPH_PORT"]),
            username=os.environ["MEMGRAPH_USERNAME"],
            password=os.environ["MEMGRAPH_PASSWORD"],  # always taken from the environment
            encrypted=True,
            ssl_verify=True,
            ca_path="/etc/ssl/certs/memgraph.crt"
        )

        # Health-check query.
        health = connection.execute_and_fetch("CALL mg.get('memgraph') YIELD value;")
        health_status = next(health)["value"]["status"]

        # Create a test node.
        # NOTE(review): relies on `datetime` being imported as
        # `from datetime import datetime` elsewhere in the file - confirm.
        connection.execute(
            'CREATE (n:ConnectionTest { message: "Hello Memgraph", ts: $ts })',
            {"ts": datetime.now().isoformat()}
        )

        return f"✅ 连接正常 | 状态: {health_status}"

    except Exception as e:
        logging.error(f"连接失败: {str(e)}", exc_info=True)
        return f"❌ 连接失败: {str(e)}"
    finally:
        # Only close what was actually opened.
        # NOTE(review): confirm the gqlalchemy Memgraph client exposes close().
        if connection is not None:
            connection.close()
|
104 |
-
|
105 |
-
|
106 |
# 配置日志
|
107 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
108 |
|
@@ -593,55 +508,35 @@ def process_text(text, model_type="bert"):
|
|
593 |
return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
|
594 |
|
595 |
# ======================== 知识图谱可视化 ========================
|
596 |
-
import matplotlib.pyplot as plt
|
597 |
-
import networkx as nx
|
598 |
-
import tempfile
|
599 |
-
import os
|
600 |
-
import logging
|
601 |
-
from matplotlib import font_manager
|
602 |
-
|
603 |
-
# 这个函数用于查找并验证中文字体路径
|
604 |
-
def find_chinese_font():
    """Return the path of the first known CJK font file that exists on disk.

    The candidates are checked in order with ``os.path.exists``. The chosen
    path is logged at INFO level; when none exists an ERROR is logged.

    Returns:
        The matching font path as a string, or ``None`` if no candidate
        file is present.
    """
    candidates = (
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",  # Noto Sans CJK
        "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc",  # WenQuanYi Micro Hei
    )

    located = next((path for path in candidates if os.path.exists(path)), None)
    if located is None:
        logging.error("No Chinese font found!")
        return None

    logging.info(f"Found font at {located}")
    return located
|
618 |
-
|
619 |
def generate_kg_image(entities, relations):
|
620 |
"""
|
621 |
-
|
622 |
"""
|
623 |
try:
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
plt.rcParams['font.family'] = ['DejaVu Sans']
|
633 |
|
634 |
-
|
|
|
|
|
635 |
|
636 |
-
# ===
|
637 |
G = nx.DiGraph()
|
638 |
entity_colors = {
|
639 |
-
'PER': '#FF6B6B', #
|
640 |
-
'ORG': '#4ECDC4', #
|
641 |
-
'LOC': '#45B7D1', #
|
642 |
-
'TIME': '#96CEB4'
|
|
|
643 |
}
|
644 |
|
|
|
645 |
for entity in entities:
|
646 |
G.add_node(
|
647 |
entity["text"],
|
@@ -649,6 +544,7 @@ def generate_kg_image(entities, relations):
|
|
649 |
color=entity_colors.get(entity['type'], '#D3D3D3')
|
650 |
)
|
651 |
|
|
|
652 |
for relation in relations:
|
653 |
if relation["head"] in G.nodes and relation["tail"] in G.nodes:
|
654 |
G.add_edge(
|
@@ -657,15 +553,15 @@ def generate_kg_image(entities, relations):
|
|
657 |
label=relation["relation"]
|
658 |
)
|
659 |
|
660 |
-
# ===
|
661 |
-
plt.figure(figsize=(12, 8), dpi=150)
|
662 |
-
pos = nx.spring_layout(G, k=0.7, seed=42)
|
663 |
|
|
|
664 |
nx.draw_networkx_nodes(
|
665 |
G, pos,
|
666 |
node_color=[G.nodes[n]['color'] for n in G.nodes],
|
667 |
-
node_size=800
|
668 |
-
alpha=0.9
|
669 |
)
|
670 |
nx.draw_networkx_edges(
|
671 |
G, pos,
|
@@ -675,44 +571,35 @@ def generate_kg_image(entities, relations):
|
|
675 |
arrowsize=20
|
676 |
)
|
677 |
|
678 |
-
|
679 |
nx.draw_networkx_labels(
|
680 |
G, pos,
|
681 |
-
labels=
|
682 |
font_size=10,
|
683 |
-
font_family=
|
684 |
-
font_weight='bold'
|
685 |
)
|
686 |
-
|
687 |
-
edge_labels = nx.get_edge_attributes(G, 'label')
|
688 |
nx.draw_networkx_edge_labels(
|
689 |
G, pos,
|
690 |
-
edge_labels=
|
691 |
font_size=8,
|
692 |
-
font_family=
|
693 |
)
|
694 |
|
695 |
plt.axis('off')
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
# 打印路径以方便调试
|
703 |
-
logging.info(f"Saving graph image to {output_path}")
|
704 |
-
|
705 |
-
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
|
706 |
plt.close()
|
707 |
-
|
708 |
-
return
|
709 |
-
|
710 |
except Exception as e:
|
711 |
-
logging.error(f"
|
712 |
return None
|
713 |
|
714 |
|
715 |
-
# ======================== 文件处理 ========================
|
716 |
def process_file(file, model_type="bert"):
|
717 |
try:
|
718 |
with open(file.name, 'rb') as f:
|
@@ -864,11 +751,13 @@ with gr.Blocks(css="""
|
|
864 |
|
865 |
with gr.Tab("📄 文本分析"):
|
866 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
867 |
-
|
868 |
-
|
869 |
-
|
870 |
-
|
871 |
-
|
|
|
|
|
872 |
|
873 |
with gr.Tab("🗂 文件分析"):
|
874 |
file_input = gr.File(file_types=[".txt", ".json"])
|
@@ -903,10 +792,4 @@ with gr.Blocks(css="""
|
|
903 |
import_output = gr.Textbox(label="导入日志")
|
904 |
import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
|
905 |
|
906 |
-
gr.Markdown("### 测试 Memgraph 数据库连接")
|
907 |
-
memgraph_btn = gr.Button("测试 Memgraph 连接")
|
908 |
-
memgraph_output = gr.Textbox(label="连接测试日志")
|
909 |
-
memgraph_btn.click(fn=hello_memgraph, outputs=memgraph_output)
|
910 |
-
|
911 |
-
|
912 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
7 |
import chardet
|
8 |
from sklearn.metrics import precision_score, recall_score, f1_score
|
9 |
import time
|
10 |
+
from functools import lru_cache # 添加这行导入
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# ======================== 数据库模块 ========================
|
12 |
from sqlalchemy import create_engine
|
13 |
from sqlalchemy.orm import sessionmaker
|
|
|
18 |
import pandas as pd
|
19 |
import matplotlib.pyplot as plt
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
# 配置日志
|
22 |
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
23 |
|
|
|
508 |
return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
|
509 |
|
510 |
# ======================== 知识图谱可视化 ========================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
511 |
def generate_kg_image(entities, relations):
|
512 |
"""
|
513 |
+
生成知识图谱的图片并保存到临时文件(Hugging Face适配版)
|
514 |
"""
|
515 |
try:
|
516 |
+
import tempfile
|
517 |
+
import matplotlib.pyplot as plt
|
518 |
+
import networkx as nx
|
519 |
+
import os
|
520 |
+
|
521 |
+
# === 1. 强制设置中文字体 ===
|
522 |
+
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC'] # Hugging Face内置字体
|
523 |
+
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
|
|
|
524 |
|
525 |
+
# === 2. 检查输入数据 ===
|
526 |
+
if not entities or not relations:
|
527 |
+
return None
|
528 |
|
529 |
+
# === 3. 创建图谱 ===
|
530 |
G = nx.DiGraph()
|
531 |
entity_colors = {
|
532 |
+
'PER': '#FF6B6B', # 红色
|
533 |
+
'ORG': '#4ECDC4', # 青色
|
534 |
+
'LOC': '#45B7D1', # 蓝色
|
535 |
+
'TIME': '#96CEB4', # 绿色
|
536 |
+
'TITLE': '#D4A5A5' # 粉色 (dusty pink)
|
537 |
}
|
538 |
|
539 |
+
# 添加节点(实体)
|
540 |
for entity in entities:
|
541 |
G.add_node(
|
542 |
entity["text"],
|
|
|
544 |
color=entity_colors.get(entity['type'], '#D3D3D3')
|
545 |
)
|
546 |
|
547 |
+
# 添加边(关系)
|
548 |
for relation in relations:
|
549 |
if relation["head"] in G.nodes and relation["tail"] in G.nodes:
|
550 |
G.add_edge(
|
|
|
553 |
label=relation["relation"]
|
554 |
)
|
555 |
|
556 |
+
# === 4. 绘图配置 ===
|
557 |
+
plt.figure(figsize=(12, 8), dpi=150) # 降低DPI以节省内存
|
558 |
+
pos = nx.spring_layout(G, k=0.7, seed=42) # 固定随机种子保证布局稳定
|
559 |
|
560 |
+
# 绘制节点和边
|
561 |
nx.draw_networkx_nodes(
|
562 |
G, pos,
|
563 |
node_color=[G.nodes[n]['color'] for n in G.nodes],
|
564 |
+
node_size=800
|
|
|
565 |
)
|
566 |
nx.draw_networkx_edges(
|
567 |
G, pos,
|
|
|
571 |
arrowsize=20
|
572 |
)
|
573 |
|
574 |
+
# === 5. 绘制中文标签(关键修改点)===
|
575 |
nx.draw_networkx_labels(
|
576 |
G, pos,
|
577 |
+
labels={n: G.nodes[n]['label'] for n in G.nodes},
|
578 |
font_size=10,
|
579 |
+
font_family='Noto Sans CJK SC' # 显式指定字体
|
|
|
580 |
)
|
|
|
|
|
581 |
nx.draw_networkx_edge_labels(
|
582 |
G, pos,
|
583 |
+
edge_labels=nx.get_edge_attributes(G, 'label'),
|
584 |
font_size=8,
|
585 |
+
font_family='Noto Sans CJK SC' # 显式指定字体
|
586 |
)
|
587 |
|
588 |
plt.axis('off')
|
589 |
+
|
590 |
+
# === 6. 保存到临时文件 ===
|
591 |
+
temp_dir = tempfile.mkdtemp()
|
592 |
+
file_path = os.path.join(temp_dir, "kg.png")
|
593 |
+
plt.savefig(file_path, bbox_inches='tight')
|
|
|
|
|
|
|
|
|
|
|
594 |
plt.close()
|
595 |
+
|
596 |
+
return file_path
|
597 |
+
|
598 |
except Exception as e:
|
599 |
+
logging.error(f"生成知识图谱图片失败: {str(e)}")
|
600 |
return None
|
601 |
|
602 |
|
|
|
603 |
def process_file(file, model_type="bert"):
|
604 |
try:
|
605 |
with open(file.name, 'rb') as f:
|
|
|
751 |
|
752 |
with gr.Tab("📄 文本分析"):
|
753 |
input_text = gr.Textbox(lines=6, label="输入文本")
|
754 |
+
model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
|
755 |
+
btn = gr.Button("开始分析")
|
756 |
+
out1 = gr.Textbox(label="识别实体")
|
757 |
+
out2 = gr.Textbox(label="识别关系")
|
758 |
+
out3 = gr.HTML(label="知识图谱") # 使用HTML组件显示文本格式的知识图谱
|
759 |
+
out4 = gr.Textbox(label="耗时")
|
760 |
+
btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
|
761 |
|
762 |
with gr.Tab("🗂 文件分析"):
|
763 |
file_input = gr.File(file_types=[".txt", ".json"])
|
|
|
792 |
import_output = gr.Textbox(label="导入日志")
|
793 |
import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
|
794 |
|
|
|
|
|
|
|
|
|
|
|
|
|
795 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|