chen666-666 committed on
Commit
1e5ba7c
·
1 Parent(s): b8c346d

更新代码:添加新的功能

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # 默认忽略的文件
2
+ /shelf/
3
+ /workspace.xml
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/wechat-ner-re.iml" filepath="$PROJECT_DIR$/.idea/wechat-ner-re.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
.idea/wechat-ner-re.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import BertTokenizer, BertModel
3
  import gradio as gr
4
  import re
5
  import os
@@ -8,11 +8,16 @@ import pandas as pd
8
  import chardet
9
  from pyvis.network import Network
10
  import networkx as nx
 
11
 
12
- # 初始化模型
13
- model_name = "bert-base-chinese"
14
- tokenizer = BertTokenizer.from_pretrained(model_name)
15
- model = BertModel.from_pretrained(model_name)
 
 
 
 
16
 
17
  # 知识图谱数据存储
18
  knowledge_graph = {
@@ -56,7 +61,7 @@ def visualize_kg():
56
  font={'size': 14})
57
  seen_edges.add(edge_key)
58
 
59
- net.set_options("""
60
  {
61
  "nodes": {
62
  "scaling": {
@@ -88,7 +93,16 @@ def visualize_kg():
88
  return f'<div class="kg-graph">{html}</div>'
89
 
90
 
91
- def ner(text):
 
 
 
 
 
 
 
 
 
92
  name_pattern = r"([赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2})(?![的地得啦啊呀])"
93
  id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
94
 
@@ -120,60 +134,13 @@ def ner(text):
120
  })
121
  occupied.add((start, end))
122
 
123
- return sorted(entities, key=lambda x: x["start"])
124
-
125
-
126
- def re_extract(entities, text):
127
- relations = []
128
- triggers = {
129
- "recommend": ["推荐", "引荐", "建议", "找"],
130
- "send_to": ["发送", "转发", "发给", "抄送"],
131
- "mention": ["提到", "提及", "@", "说"]
132
- }
133
-
134
- for i in range(len(entities)):
135
- for j in range(max(0, i - 2), min(len(entities), i + 3)):
136
- if i == j:
137
- continue
138
-
139
- ctx_start = entities[i]["end"]
140
- ctx_end = entities[j]["start"]
141
- context = text[ctx_start:ctx_end].strip()
142
-
143
- if text.startswith('@', entities[i]["start"] - 1):
144
- relations.append({
145
- "head": entities[i]["text"],
146
- "tail": entities[j]["text"],
147
- "relation": "mention"
148
- })
149
- continue
150
-
151
- relation_type = "knows"
152
- for rel_type, keywords in triggers.items():
153
- if any(kw in context for kw in keywords):
154
- relation_type = rel_type
155
- break
156
-
157
- relations.append({
158
- "head": entities[i]["text"],
159
- "tail": entities[j]["text"],
160
- "relation": relation_type
161
- })
162
-
163
- unique_relations = []
164
- seen = set()
165
- for rel in relations:
166
- key = (rel["head"], rel["tail"], rel["relation"])
167
- if key not in seen:
168
- unique_relations.append(rel)
169
- seen.add(key)
170
-
171
- return unique_relations
172
 
173
 
174
- def process_text(text):
175
  try:
176
- entities = ner(text)
177
  relations = re_extract(entities, text)
178
  update_knowledge_graph(entities, relations)
179
 
@@ -187,10 +154,10 @@ def process_text(text):
187
  )
188
  kg_html = visualize_kg()
189
 
190
- return entity_output, relation_output, gr.HTML(kg_html)
191
 
192
  except Exception as e:
193
- return f"处理出错: {str(e)}", "", gr.HTML()
194
 
195
 
196
  def detect_encoding(file_path):
@@ -198,7 +165,7 @@ def detect_encoding(file_path):
198
  return chardet.detect(f.read(4096))['encoding'] or 'utf-8'
199
 
200
 
201
- def process_file(file):
202
  ext = os.path.splitext(file.name)[-1].lower()
203
  full_text = ""
204
 
@@ -238,10 +205,10 @@ def process_file(file):
238
  else:
239
  return f"❌ 不支持的文件类型: {ext}", "", gr.HTML()
240
 
241
- return process_text(full_text)
242
 
243
  except Exception as e:
244
- return f"❌ 文件处理错误: {str(e)}", "", gr.HTML()
245
 
246
 
247
  # Gradio UI
@@ -267,12 +234,14 @@ with gr.Blocks(css=css) as demo:
267
  gr.Markdown("### 直接输入聊天内容")
268
  input_text = gr.Textbox(label="输入内容", lines=8,
269
  placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
 
270
  analyze_btn = gr.Button("开始分析", variant="primary")
271
 
272
  with gr.Row():
273
  entity_output = gr.Textbox(label="识别实体", interactive=False)
274
  relation_output = gr.Textbox(label="发现关系", interactive=False)
275
  kg_display = gr.HTML(label="知识图谱", elem_classes="kg-container")
 
276
 
277
  analyze_btn.click(
278
  fn=lambda: gr.update(interactive=False),
@@ -280,8 +249,8 @@ with gr.Blocks(css=css) as demo:
280
  outputs=analyze_btn
281
  ).then(
282
  fn=process_text,
283
- inputs=[input_text],
284
- outputs=[entity_output, relation_output, kg_display]
285
  ).then(
286
  fn=lambda: gr.update(interactive=True),
287
  inputs=None,
@@ -291,17 +260,19 @@ with gr.Blocks(css=css) as demo:
291
  with gr.Tab("📁 文件分析"):
292
  gr.Markdown("### 上传聊天记录文件")
293
  file_input = gr.File(label="选择文件", file_types=[".txt", ".json", ".jsonl", ".csv"])
 
294
  file_btn = gr.Button("分析文件", variant="primary")
295
 
296
  with gr.Row():
297
  file_entity = gr.Textbox(label="识别实体", interactive=False)
298
  file_relation = gr.Textbox(label="发现关系", interactive=False)
299
  file_kg = gr.HTML(elem_classes="kg-container")
 
300
 
301
  file_btn.click(
302
  fn=process_file,
303
- inputs=[file_input],
304
- outputs=[file_entity, file_relation, file_kg]
305
  )
306
 
307
  with gr.Tab("🗺️ 完整图谱"):
 
1
  import torch
2
+ from transformers import BertTokenizer, BertModel, LlamaTokenizer, LlamaForCausalLM
3
  import gradio as gr
4
  import re
5
  import os
 
8
  import chardet
9
  from pyvis.network import Network
10
  import networkx as nx
11
+ import time
12
 
13
+ # 初始化 BERT 和 LLaMA 2 模型
14
+ bert_model_name = "bert-base-chinese"
15
+ bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
16
+ bert_model = BertModel.from_pretrained(bert_model_name)
17
+
18
+ llama_model_name = "meta-llama/Llama-2-7b-chat-hf"
19
+ llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_name)
20
+ llama_model = LlamaForCausalLM.from_pretrained(llama_model_name)
21
 
22
  # 知识图谱数据存储
23
  knowledge_graph = {
 
61
  font={'size': 14})
62
  seen_edges.add(edge_key)
63
 
64
+ net.set_options("""
65
  {
66
  "nodes": {
67
  "scaling": {
 
93
  return f'<div class="kg-graph">{html}</div>'
94
 
95
 
96
+ def ner(text, model_type="bert"):
97
+ # 选择模型进行处理
98
+ start_time = time.time()
99
+ if model_type == "bert":
100
+ tokenizer = bert_tokenizer
101
+ model = bert_model
102
+ elif model_type == "llama":
103
+ tokenizer = llama_tokenizer
104
+ model = llama_model
105
+
106
  name_pattern = r"([赵钱孙李周吴郑王冯陈褚卫蒋沈韩杨朱秦尤许何吕施张孔曹严华金魏陶姜][\u4e00-\u9fa5]{1,2})(?![的地得啦啊呀])"
107
  id_pattern = r"(?<!\S)([a-zA-Z_][a-zA-Z0-9_]{4,})(?![\u4e00-\u9fa5])"
108
 
 
134
  })
135
  occupied.add((start, end))
136
 
137
+ processing_time = time.time() - start_time
138
+ return entities, processing_time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
 
141
+ def process_text(text, model_type="bert"):
142
  try:
143
+ entities, processing_time = ner(text, model_type=model_type)
144
  relations = re_extract(entities, text)
145
  update_knowledge_graph(entities, relations)
146
 
 
154
  )
155
  kg_html = visualize_kg()
156
 
157
+ return entity_output, relation_output, gr.HTML(kg_html), f"处理时间:{processing_time:.2f}秒"
158
 
159
  except Exception as e:
160
+ return f"处理出错: {str(e)}", "", gr.HTML(), ""
161
 
162
 
163
  def detect_encoding(file_path):
 
165
  return chardet.detect(f.read(4096))['encoding'] or 'utf-8'
166
 
167
 
168
+ def process_file(file, model_type="bert"):
169
  ext = os.path.splitext(file.name)[-1].lower()
170
  full_text = ""
171
 
 
205
  else:
206
  return f"❌ 不支持的文件类型: {ext}", "", gr.HTML()
207
 
208
+ return process_text(full_text, model_type)
209
 
210
  except Exception as e:
211
+ return f"❌ 文件处理错误: {str(e)}", "", gr.HTML(), ""
212
 
213
 
214
  # Gradio UI
 
234
  gr.Markdown("### 直接输入聊天内容")
235
  input_text = gr.Textbox(label="输入内容", lines=8,
236
  placeholder="示例:张三@李四 请把需求文档_v2发送给王五")
237
+ model_type = gr.Radio(["bert", "llama"], label="选择模型", value="bert")
238
  analyze_btn = gr.Button("开始分析", variant="primary")
239
 
240
  with gr.Row():
241
  entity_output = gr.Textbox(label="识别实体", interactive=False)
242
  relation_output = gr.Textbox(label="发现关系", interactive=False)
243
  kg_display = gr.HTML(label="知识图谱", elem_classes="kg-container")
244
+ time_output = gr.Textbox(label="处理时间", interactive=False)
245
 
246
  analyze_btn.click(
247
  fn=lambda: gr.update(interactive=False),
 
249
  outputs=analyze_btn
250
  ).then(
251
  fn=process_text,
252
+ inputs=[input_text, model_type],
253
+ outputs=[entity_output, relation_output, kg_display, time_output]
254
  ).then(
255
  fn=lambda: gr.update(interactive=True),
256
  inputs=None,
 
260
  with gr.Tab("📁 文件分析"):
261
  gr.Markdown("### 上传聊天记录文件")
262
  file_input = gr.File(label="选择文件", file_types=[".txt", ".json", ".jsonl", ".csv"])
263
+ file_model_type = gr.Radio(["bert", "llama"], label="选择模型", value="bert")
264
  file_btn = gr.Button("分析文件", variant="primary")
265
 
266
  with gr.Row():
267
  file_entity = gr.Textbox(label="识别实体", interactive=False)
268
  file_relation = gr.Textbox(label="发现关系", interactive=False)
269
  file_kg = gr.HTML(elem_classes="kg-container")
270
+ file_time = gr.Textbox(label="处理时间", interactive=False)
271
 
272
  file_btn.click(
273
  fn=process_file,
274
+ inputs=[file_input, file_model_type],
275
+ outputs=[file_entity, file_relation, file_kg, file_time]
276
  )
277
 
278
  with gr.Tab("🗺️ 完整图谱"):