chen666-666 commited on
Commit
2543020
·
verified ·
1 Parent(s): 3822c77

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +33 -44
  2. db_config.ini +6 -0
app.py CHANGED
@@ -10,13 +10,13 @@ import time
10
  # ======================== 数据库模块 ========================
11
  import pymysql
12
  from configparser import ConfigParser
 
13
 
14
-
15
  def get_db_connection():
16
  config = ConfigParser()
17
  config.read('db_config.ini')
18
-
19
- return pymysql.connect(
20
  host=config.get('mysql', 'host'),
21
  user=config.get('mysql', 'user'),
22
  password=config.get('mysql', 'password'),
@@ -25,7 +25,10 @@ def get_db_connection():
25
  charset=config.get('mysql', 'charset', fallback='utf8mb4'),
26
  cursorclass=pymysql.cursors.DictCursor
27
  )
28
-
 
 
 
29
 
30
  def save_to_db(table, data):
31
  conn = None
@@ -122,30 +125,27 @@ def visualize_kg_text():
122
 
123
  # ======================== 实体识别(NER) ========================
124
  def merge_adjacent_entities(entities):
125
- merged = []
126
- for entity in entities:
127
- if not merged:
128
- merged.append(entity)
129
- continue
130
 
 
 
131
  last = merged[-1]
132
  # 合并相邻的同类型实体
133
  if (entity["type"] == last["type"] and
134
- entity["start"] == last["end"] and
135
- entity["text"] not in last["text"]):
136
- merged[-1] = {
137
- "text": last["text"] + entity["text"],
138
- "type": last["type"],
139
- "start": last["start"],
140
- "end": entity["end"]
141
- }
142
  else:
143
  merged.append(entity)
 
144
  return merged
145
 
146
 
147
  def ner(text, model_type="bert"):
148
  start_time = time.time()
 
 
149
  if model_type == "chatglm" and use_chatglm:
150
  try:
151
  prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
@@ -155,47 +155,35 @@ def ner(text, model_type="bert"):
155
  if isinstance(response, tuple):
156
  response = response[0]
157
 
158
- # 增强 JSON 解析
159
  try:
160
  json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
161
  entities = json.loads(json_str)
162
- # 验证字段
163
- valid_entities = []
164
- for ent in entities:
165
- if all(k in ent for k in ("text", "type", "start", "end")):
166
- valid_entities.append(ent)
167
  return valid_entities, time.time() - start_time
168
  except Exception as e:
169
- print(f"JSON 解析失败: {e}")
170
  return [], time.time() - start_time
171
  except Exception as e:
172
- print(f"ChatGLM 调用失败:{e}")
173
  return [], time.time() - start_time
174
 
175
- # 使用微调的 BERT 中文 NER 模型
 
176
  raw_results = []
177
- max_len = 510 # 安全一点,留一点空余
178
- text_chunks = [text[i:i + max_len] for i in range(0, len(text), max_len)]
179
-
180
  for idx, chunk in enumerate(text_chunks):
181
  chunk_results = bert_ner_pipeline(chunk)
182
- # 修正每个 chunk 里识别的实体在整体文本中的位置
183
  for r in chunk_results:
184
- r["start"] += idx * max_len
185
- r["end"] += idx * max_len
186
  raw_results.extend(chunk_results)
187
 
188
- entities = []
189
- for r in raw_results:
190
- mapped_type = LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
191
- entities.append({
192
- "text": r['word'].replace(' ', ''),
193
- "start": r['start'],
194
- "end": r['end'],
195
- "type": mapped_type
196
- })
197
-
198
- # 执行合并处理
199
  entities = merge_adjacent_entities(entities)
200
  return entities, time.time() - start_time
201
 
@@ -349,7 +337,7 @@ def process_file(file, model_type="bert"):
349
  text = content.decode(encoding)
350
  except UnicodeDecodeError:
351
  # 尝试常见中文编码
352
- for enc in ['gb18030', 'utf-16', 'big5']:
353
  try:
354
  text = content.decode(enc)
355
  break
@@ -363,6 +351,7 @@ def process_file(file, model_type="bert"):
363
  return f"❌ 文件处理错误: {str(e)}", "", "", ""
364
 
365
 
 
366
  # ======================== 模型评估与自动标注 ========================
367
  def convert_telegram_json_to_eval_format(path):
368
  with open(path, encoding="utf-8") as f:
 
10
  # ======================== 数据库模块 ========================
11
  import pymysql
12
  from configparser import ConfigParser
13
+ from contextlib import contextmanager
14
 
15
+ @contextmanager
16
  def get_db_connection():
17
  config = ConfigParser()
18
  config.read('db_config.ini')
19
+ conn = pymysql.connect(
 
20
  host=config.get('mysql', 'host'),
21
  user=config.get('mysql', 'user'),
22
  password=config.get('mysql', 'password'),
 
25
  charset=config.get('mysql', 'charset', fallback='utf8mb4'),
26
  cursorclass=pymysql.cursors.DictCursor
27
  )
28
+ try:
29
+ yield conn
30
+ finally:
31
+ conn.close()
32
 
33
  def save_to_db(table, data):
34
  conn = None
 
125
 
126
  # ======================== 实体识别(NER) ========================
127
  def merge_adjacent_entities(entities):
128
+ if not entities:
129
+ return entities
 
 
 
130
 
131
+ merged = [entities[0]]
132
+ for entity in entities[1:]:
133
  last = merged[-1]
134
  # 合并相邻的同类型实体
135
  if (entity["type"] == last["type"] and
136
+ entity["start"] == last["end"]):
137
+ last["text"] += entity["text"]
138
+ last["end"] = entity["end"]
 
 
 
 
 
139
  else:
140
  merged.append(entity)
141
+
142
  return merged
143
 
144
 
145
  def ner(text, model_type="bert"):
146
  start_time = time.time()
147
+
148
+ # 如果使用的是 ChatGLM 模型,执行 ChatGLM 的NER
149
  if model_type == "chatglm" and use_chatglm:
150
  try:
151
  prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
 
155
  if isinstance(response, tuple):
156
  response = response[0]
157
 
 
158
  try:
159
  json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
160
  entities = json.loads(json_str)
161
+ valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
 
 
 
 
162
  return valid_entities, time.time() - start_time
163
  except Exception as e:
164
+ print(f"JSON解析失败: {e}")
165
  return [], time.time() - start_time
166
  except Exception as e:
167
+ print(f"ChatGLM调用失败: {e}")
168
  return [], time.time() - start_time
169
 
170
+ # 使用BERT NER
171
+ text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)] # 安全分段
172
  raw_results = []
 
 
 
173
  for idx, chunk in enumerate(text_chunks):
174
  chunk_results = bert_ner_pipeline(chunk)
 
175
  for r in chunk_results:
176
+ r["start"] += idx * 510
177
+ r["end"] += idx * 510
178
  raw_results.extend(chunk_results)
179
 
180
+ entities = [{
181
+ "text": r['word'].replace(' ', ''),
182
+ "start": r['start'],
183
+ "end": r['end'],
184
+ "type": LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
185
+ } for r in raw_results]
186
+
 
 
 
 
187
  entities = merge_adjacent_entities(entities)
188
  return entities, time.time() - start_time
189
 
 
337
  text = content.decode(encoding)
338
  except UnicodeDecodeError:
339
  # 尝试常见中文编码
340
+ for enc in ['gb18030', 'utf-16', 'big5'] :
341
  try:
342
  text = content.decode(enc)
343
  break
 
351
  return f"❌ 文件处理错误: {str(e)}", "", "", ""
352
 
353
 
354
+
355
  # ======================== 模型评估与自动标注 ========================
356
  def convert_telegram_json_to_eval_format(path):
357
  with open(path, encoding="utf-8") as f:
db_config.ini ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [mysql]
2
+ host = localhost
3
+ user = root
4
+ password = 123456
5
+ database = entity_kg
6
+ charset = utf8mb4