chen666-666 commited on
Commit
763c565
·
verified ·
1 Parent(s): b99a17e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -164
app.py CHANGED
@@ -7,52 +7,7 @@ import json
7
  import chardet
8
  from sklearn.metrics import precision_score, recall_score, f1_score
9
  import time
10
- from functools import lru_cache
11
- from sqlalchemy import create_engine
12
- from sqlalchemy.orm import sessionmaker
13
- from contextlib import contextmanager
14
- import logging
15
- import networkx as nx
16
- from pyvis.network import Network
17
- import pandas as pd
18
- import matplotlib.pyplot as plt
19
- from gqlalchemy import Memgraph
20
- from mcp_use import RelationPredictor, insert_to_memgraph, get_memgraph_conn # 引入mcp_use中的功能
21
- from relation_extraction.hparams import hparams # 引入模型超参数
22
-
23
- # ======================== 数据库模块 ========================
24
- MEMGRAPH_HOST = '18.159.132.161'
25
- MEMGRAPH_PORT = 7687
26
- MEMGRAPH_USERNAME = '[email protected]'
27
- MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
28
-
29
- # 初始化 Memgraph 连接
30
- memgraph = get_memgraph_conn()
31
-
32
- # 初始化关系抽取模型
33
- relation_predictor = RelationPredictor(hparams)
34
-
35
- # ======================== 关系抽取功能整合 ========================
36
- def extract_and_save_relations(text, entity1, entity2):
37
- """
38
- 使用 mcp_use.py 中的 RelationPredictor 提取关系,并保存到 Memgraph
39
- """
40
- try:
41
- # 调用关系抽取模型
42
- result = relation_predictor.predict_one(text, entity1, entity2)
43
- if result is None:
44
- return f"❌ 未找到实体 '{entity1}' 或 '{entity2}'"
45
-
46
- # 提取关系
47
- entity1, relation, entity2 = result
48
-
49
- # 保存到 Memgraph
50
- insert_to_memgraph(memgraph, entity1, relation, entity2)
51
- return f"✅ 已写入 Memgraph:({entity1})-[:{relation}]->({entity2})"
52
- except Exception as e:
53
- logging.error(f"关系抽取失败: {e}")
54
- return f"❌ 关系抽取失败: {e}"
55
-
56
  # ======================== 数据库模块 ========================
57
  from sqlalchemy import create_engine
58
  from sqlalchemy.orm import sessionmaker
@@ -63,46 +18,6 @@ from pyvis.network import Network
63
  import pandas as pd
64
  import matplotlib.pyplot as plt
65
 
66
-
67
- from gqlalchemy import Memgraph
68
- import os
69
- MEMGRAPH_HOST = '18.159.132.161'
70
- MEMGRAPH_PORT = 7687
71
- MEMGRAPH_USERNAME = '[email protected]'
72
- MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "<YOUR MEMGRAPH PASSWORD HERE>")
73
-
74
- def hello_memgraph():
75
- """测试 Memgraph 数据库连接并进行健康检查"""
76
- try:
77
- connection = Memgraph(
78
- host=os.environ["MEMGRAPH_HOST"],
79
- port=int(os.environ["MEMGRAPH_PORT"]),
80
- username=os.environ["MEMGRAPH_USERNAME"],
81
- password=os.environ["MEMGRAPH_PASSWORD"], # 强制从环境变量获取
82
- encrypted=True,
83
- ssl_verify=True,
84
- ca_path="/etc/ssl/certs/memgraph.crt"
85
- )
86
-
87
- # 健康检查查询
88
- health = connection.execute_and_fetch("CALL mg.get('memgraph') YIELD value;")
89
- health_status = next(health)["value"]["status"]
90
-
91
- # 创建测试节点
92
- connection.execute(
93
- 'CREATE (n:ConnectionTest { message: "Hello Memgraph", ts: $ts })',
94
- {"ts": datetime.now().isoformat()}
95
- )
96
-
97
- return f"✅ 连接正常 | 状态: {health_status}"
98
-
99
- except Exception as e:
100
- logging.error(f"连接失败: {str(e)}", exc_info=True)
101
- return f"❌ 连接失败: {str(e)}"
102
- finally:
103
- connection.close()
104
-
105
-
106
  # 配置日志
107
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
108
 
@@ -593,55 +508,35 @@ def process_text(text, model_type="bert"):
593
  return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
594
 
595
  # ======================== 知识图谱可视化 ========================
596
- import matplotlib.pyplot as plt
597
- import networkx as nx
598
- import tempfile
599
- import os
600
- import logging
601
- from matplotlib import font_manager
602
-
603
- # 这个函数用于查找并验证中文字体路径
604
- def find_chinese_font():
605
- # 尝试查找 Noto Sans CJK 字体
606
- font_paths = [
607
- "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", # Noto CJK 字体
608
- "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc" # 微软雅黑
609
- ]
610
-
611
- for font_path in font_paths:
612
- if os.path.exists(font_path):
613
- logging.info(f"Found font at {font_path}")
614
- return font_path
615
-
616
- logging.error("No Chinese font found!")
617
- return None
618
-
619
  def generate_kg_image(entities, relations):
620
  """
621
- 中文知识图谱生成函数,支持自动匹配系统中的中文字体,避免中文显示为方框。
622
  """
623
  try:
624
- # === 1. 确保使用合适的中文字体 ===
625
- chinese_font = find_chinese_font() # 调用查找字体函数
626
- if chinese_font:
627
- font_prop = font_manager.FontProperties(fname=chinese_font)
628
- plt.rcParams['font.family'] = font_prop.get_name()
629
- else:
630
- # 如果字体路径未找到,使用默认字体(DejaVu Sans)
631
- logging.warning("Using default font")
632
- plt.rcParams['font.family'] = ['DejaVu Sans']
633
 
634
- plt.rcParams['axes.unicode_minus'] = False
 
 
635
 
636
- # === 2. 创建图谱 ===
637
  G = nx.DiGraph()
638
  entity_colors = {
639
- 'PER': '#FF6B6B', # 人物-红色
640
- 'ORG': '#4ECDC4', # 组织-青色
641
- 'LOC': '#45B7D1', # 地点-蓝色
642
- 'TIME': '#96CEB4' # 时间-绿色
 
643
  }
644
 
 
645
  for entity in entities:
646
  G.add_node(
647
  entity["text"],
@@ -649,6 +544,7 @@ def generate_kg_image(entities, relations):
649
  color=entity_colors.get(entity['type'], '#D3D3D3')
650
  )
651
 
 
652
  for relation in relations:
653
  if relation["head"] in G.nodes and relation["tail"] in G.nodes:
654
  G.add_edge(
@@ -657,15 +553,15 @@ def generate_kg_image(entities, relations):
657
  label=relation["relation"]
658
  )
659
 
660
- # === 3. 绘图配置 ===
661
- plt.figure(figsize=(12, 8), dpi=150)
662
- pos = nx.spring_layout(G, k=0.7, seed=42)
663
 
 
664
  nx.draw_networkx_nodes(
665
  G, pos,
666
  node_color=[G.nodes[n]['color'] for n in G.nodes],
667
- node_size=800,
668
- alpha=0.9
669
  )
670
  nx.draw_networkx_edges(
671
  G, pos,
@@ -675,44 +571,35 @@ def generate_kg_image(entities, relations):
675
  arrowsize=20
676
  )
677
 
678
- node_labels = {n: G.nodes[n]['label'] for n in G.nodes}
679
  nx.draw_networkx_labels(
680
  G, pos,
681
- labels=node_labels,
682
  font_size=10,
683
- font_family=font_prop.get_name() if chinese_font else 'SimHei',
684
- font_weight='bold'
685
  )
686
-
687
- edge_labels = nx.get_edge_attributes(G, 'label')
688
  nx.draw_networkx_edge_labels(
689
  G, pos,
690
- edge_labels=edge_labels,
691
  font_size=8,
692
- font_family=font_prop.get_name() if chinese_font else 'SimHei'
693
  )
694
 
695
  plt.axis('off')
696
- plt.tight_layout()
697
-
698
- # === 4. 保存图片 ===
699
- temp_dir = tempfile.mkdtemp() # 确保在 Docker 容器中有权限写入
700
- output_path = os.path.join(temp_dir, "kg.png")
701
-
702
- # 打印路径以方便调试
703
- logging.info(f"Saving graph image to {output_path}")
704
-
705
- plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1)
706
  plt.close()
707
-
708
- return output_path
709
-
710
  except Exception as e:
711
- logging.error(f"[ERROR] 图谱生成失败: {str(e)}")
712
  return None
713
 
714
 
715
- # ======================== 文件处理 ========================
716
  def process_file(file, model_type="bert"):
717
  try:
718
  with open(file.name, 'rb') as f:
@@ -864,11 +751,13 @@ with gr.Blocks(css="""
864
 
865
  with gr.Tab("📄 文本分析"):
866
  input_text = gr.Textbox(lines=6, label="输入文本")
867
- entity1 = gr.Textbox(label="实体1")
868
- entity2 = gr.Textbox(label="实体2")
869
- btn = gr.Button("提取关系并保存到 Memgraph")
870
- output = gr.Textbox(label="结果")
871
- btn.click(fn=extract_and_save_relations, inputs=[input_text, entity1, entity2], outputs=output)
 
 
872
 
873
  with gr.Tab("🗂 文件分析"):
874
  file_input = gr.File(file_types=[".txt", ".json"])
@@ -903,10 +792,4 @@ with gr.Blocks(css="""
903
  import_output = gr.Textbox(label="导入日志")
904
  import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
905
 
906
- gr.Markdown("### 测试 Memgraph 数据库连接")
907
- memgraph_btn = gr.Button("测试 Memgraph 连接")
908
- memgraph_output = gr.Textbox(label="连接测试日志")
909
- memgraph_btn.click(fn=hello_memgraph, outputs=memgraph_output)
910
-
911
-
912
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
7
  import chardet
8
  from sklearn.metrics import precision_score, recall_score, f1_score
9
  import time
10
+ from functools import lru_cache # 添加这行导入
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # ======================== 数据库模块 ========================
12
  from sqlalchemy import create_engine
13
  from sqlalchemy.orm import sessionmaker
 
18
  import pandas as pd
19
  import matplotlib.pyplot as plt
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # 配置日志
22
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
23
 
 
508
  return ent_text, rel_text, kg_text, f"{total_duration:.2f} 秒"
509
 
510
  # ======================== 知识图谱可视化 ========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
511
  def generate_kg_image(entities, relations):
512
  """
513
+ 生成知识图谱的图片并保存到临时文件(Hugging Face适配版)
514
  """
515
  try:
516
+ import tempfile
517
+ import matplotlib.pyplot as plt
518
+ import networkx as nx
519
+ import os
520
+
521
+ # === 1. 强制设置中文字体 ===
522
+ plt.rcParams['font.sans-serif'] = ['Noto Sans CJK SC'] # Hugging Face内置字体
523
+ plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
 
524
 
525
+ # === 2. 检查输入数据 ===
526
+ if not entities or not relations:
527
+ return None
528
 
529
+ # === 3. 创建图谱 ===
530
  G = nx.DiGraph()
531
  entity_colors = {
532
+ 'PER': '#FF6B6B', # 红色
533
+ 'ORG': '#4ECDC4', # 青色
534
+ 'LOC': '#45B7D1', # 蓝色
535
+ 'TIME': '#96CEB4', # 绿色
536
+ 'TITLE': '#D4A5A5' # 灰色
537
  }
538
 
539
+ # 添加节点(实体)
540
  for entity in entities:
541
  G.add_node(
542
  entity["text"],
 
544
  color=entity_colors.get(entity['type'], '#D3D3D3')
545
  )
546
 
547
+ # 添加边(关系)
548
  for relation in relations:
549
  if relation["head"] in G.nodes and relation["tail"] in G.nodes:
550
  G.add_edge(
 
553
  label=relation["relation"]
554
  )
555
 
556
+ # === 4. 绘图配置 ===
557
+ plt.figure(figsize=(12, 8), dpi=150) # 降低DPI以节省内存
558
+ pos = nx.spring_layout(G, k=0.7, seed=42) # 固定随机种子保证布局稳定
559
 
560
+ # 绘制节点和边
561
  nx.draw_networkx_nodes(
562
  G, pos,
563
  node_color=[G.nodes[n]['color'] for n in G.nodes],
564
+ node_size=800
 
565
  )
566
  nx.draw_networkx_edges(
567
  G, pos,
 
571
  arrowsize=20
572
  )
573
 
574
+ # === 5. 绘制中文标签(关键修改点)===
575
  nx.draw_networkx_labels(
576
  G, pos,
577
+ labels={n: G.nodes[n]['label'] for n in G.nodes},
578
  font_size=10,
579
+ font_family='Noto Sans CJK SC' # 显式指定字体
 
580
  )
 
 
581
  nx.draw_networkx_edge_labels(
582
  G, pos,
583
+ edge_labels=nx.get_edge_attributes(G, 'label'),
584
  font_size=8,
585
+ font_family='Noto Sans CJK SC' # 显式指定字体
586
  )
587
 
588
  plt.axis('off')
589
+
590
+ # === 6. 保存到临时文件 ===
591
+ temp_dir = tempfile.mkdtemp()
592
+ file_path = os.path.join(temp_dir, "kg.png")
593
+ plt.savefig(file_path, bbox_inches='tight')
 
 
 
 
 
594
  plt.close()
595
+
596
+ return file_path
597
+
598
  except Exception as e:
599
+ logging.error(f"生成知识图谱图片失败: {str(e)}")
600
  return None
601
 
602
 
 
603
  def process_file(file, model_type="bert"):
604
  try:
605
  with open(file.name, 'rb') as f:
 
751
 
752
  with gr.Tab("📄 文本分析"):
753
  input_text = gr.Textbox(lines=6, label="输入文本")
754
+ model_type = gr.Radio(["bert", "chatglm"], value="bert", label="选择模型")
755
+ btn = gr.Button("开始分析")
756
+ out1 = gr.Textbox(label="识别实体")
757
+ out2 = gr.Textbox(label="识别关系")
758
+ out3 = gr.HTML(label="知识图谱") # 使用HTML组件显示文本格式的知识图谱
759
+ out4 = gr.Textbox(label="耗时")
760
+ btn.click(fn=process_text, inputs=[input_text, model_type], outputs=[out1, out2, out3, out4])
761
 
762
  with gr.Tab("🗂 文件分析"):
763
  file_input = gr.File(file_types=[".txt", ".json"])
 
792
  import_output = gr.Textbox(label="导入日志")
793
  import_btn.click(fn=lambda: import_dataset(dataset_path.value), outputs=import_output)
794
 
 
 
 
 
 
 
795
  demo.launch(server_name="0.0.0.0", server_port=7860)