Spaces:
Running
Running
Upload 2 files
Browse files- app.py +33 -44
- db_config.ini +6 -0
app.py
CHANGED
@@ -10,13 +10,13 @@ import time
|
|
10 |
# ======================== 数据库模块 ========================
|
11 |
import pymysql
|
12 |
from configparser import ConfigParser
|
|
|
13 |
|
14 |
-
|
15 |
def get_db_connection():
|
16 |
config = ConfigParser()
|
17 |
config.read('db_config.ini')
|
18 |
-
|
19 |
-
return pymysql.connect(
|
20 |
host=config.get('mysql', 'host'),
|
21 |
user=config.get('mysql', 'user'),
|
22 |
password=config.get('mysql', 'password'),
|
@@ -25,7 +25,10 @@ def get_db_connection():
|
|
25 |
charset=config.get('mysql', 'charset', fallback='utf8mb4'),
|
26 |
cursorclass=pymysql.cursors.DictCursor
|
27 |
)
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
def save_to_db(table, data):
|
31 |
conn = None
|
@@ -122,30 +125,27 @@ def visualize_kg_text():
|
|
122 |
|
123 |
# ======================== 实体识别(NER) ========================
|
124 |
def merge_adjacent_entities(entities):
|
125 |
-
|
126 |
-
|
127 |
-
if not merged:
|
128 |
-
merged.append(entity)
|
129 |
-
continue
|
130 |
|
|
|
|
|
131 |
last = merged[-1]
|
132 |
# 合并相邻的同类型实体
|
133 |
if (entity["type"] == last["type"] and
|
134 |
-
entity["start"] == last["end"]
|
135 |
-
|
136 |
-
|
137 |
-
"text": last["text"] + entity["text"],
|
138 |
-
"type": last["type"],
|
139 |
-
"start": last["start"],
|
140 |
-
"end": entity["end"]
|
141 |
-
}
|
142 |
else:
|
143 |
merged.append(entity)
|
|
|
144 |
return merged
|
145 |
|
146 |
|
147 |
def ner(text, model_type="bert"):
|
148 |
start_time = time.time()
|
|
|
|
|
149 |
if model_type == "chatglm" and use_chatglm:
|
150 |
try:
|
151 |
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
|
@@ -155,47 +155,35 @@ def ner(text, model_type="bert"):
|
|
155 |
if isinstance(response, tuple):
|
156 |
response = response[0]
|
157 |
|
158 |
-
# 增强 JSON 解析
|
159 |
try:
|
160 |
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
|
161 |
entities = json.loads(json_str)
|
162 |
-
|
163 |
-
valid_entities = []
|
164 |
-
for ent in entities:
|
165 |
-
if all(k in ent for k in ("text", "type", "start", "end")):
|
166 |
-
valid_entities.append(ent)
|
167 |
return valid_entities, time.time() - start_time
|
168 |
except Exception as e:
|
169 |
-
print(f"JSON
|
170 |
return [], time.time() - start_time
|
171 |
except Exception as e:
|
172 |
-
print(f"ChatGLM
|
173 |
return [], time.time() - start_time
|
174 |
|
175 |
-
#
|
|
|
176 |
raw_results = []
|
177 |
-
max_len = 510 # 安全一点,留一点空余
|
178 |
-
text_chunks = [text[i:i + max_len] for i in range(0, len(text), max_len)]
|
179 |
-
|
180 |
for idx, chunk in enumerate(text_chunks):
|
181 |
chunk_results = bert_ner_pipeline(chunk)
|
182 |
-
# 修正每个 chunk 里识别的实体在整体文本中的位置
|
183 |
for r in chunk_results:
|
184 |
-
r["start"] += idx *
|
185 |
-
r["end"] += idx *
|
186 |
raw_results.extend(chunk_results)
|
187 |
|
188 |
-
entities = [
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
"type": mapped_type
|
196 |
-
})
|
197 |
-
|
198 |
-
# 执行合并处理
|
199 |
entities = merge_adjacent_entities(entities)
|
200 |
return entities, time.time() - start_time
|
201 |
|
@@ -349,7 +337,7 @@ def process_file(file, model_type="bert"):
|
|
349 |
text = content.decode(encoding)
|
350 |
except UnicodeDecodeError:
|
351 |
# 尝试常见中文编码
|
352 |
-
for enc in ['gb18030', 'utf-16', 'big5']:
|
353 |
try:
|
354 |
text = content.decode(enc)
|
355 |
break
|
@@ -363,6 +351,7 @@ def process_file(file, model_type="bert"):
|
|
363 |
return f"❌ 文件处理错误: {str(e)}", "", "", ""
|
364 |
|
365 |
|
|
|
366 |
# ======================== 模型评估与自动标注 ========================
|
367 |
def convert_telegram_json_to_eval_format(path):
|
368 |
with open(path, encoding="utf-8") as f:
|
|
|
10 |
# ======================== 数据库模块 ========================
|
11 |
import pymysql
|
12 |
from configparser import ConfigParser
|
13 |
+
from contextlib import contextmanager
|
14 |
|
15 |
+
@contextmanager
|
16 |
def get_db_connection():
|
17 |
config = ConfigParser()
|
18 |
config.read('db_config.ini')
|
19 |
+
conn = pymysql.connect(
|
|
|
20 |
host=config.get('mysql', 'host'),
|
21 |
user=config.get('mysql', 'user'),
|
22 |
password=config.get('mysql', 'password'),
|
|
|
25 |
charset=config.get('mysql', 'charset', fallback='utf8mb4'),
|
26 |
cursorclass=pymysql.cursors.DictCursor
|
27 |
)
|
28 |
+
try:
|
29 |
+
yield conn
|
30 |
+
finally:
|
31 |
+
conn.close()
|
32 |
|
33 |
def save_to_db(table, data):
|
34 |
conn = None
|
|
|
125 |
|
126 |
# ======================== 实体识别(NER) ========================
|
127 |
def merge_adjacent_entities(entities):
    """Merge adjacent same-type entities into single spans.

    Two entities are merged when they share the same ``type`` and the second
    one starts exactly where the first one ends (``start == end``) — the
    typical case when the tokenizer splits one surface form into pieces.

    Args:
        entities: list of dicts with ``text``, ``type``, ``start``, ``end``
            keys, assumed sorted by ``start``.

    Returns:
        A new list of entity dicts. Unlike the previous implementation,
        the input list and its dicts are never mutated — merging used to
        write ``text``/``end`` back into the caller's first entity dict.
    """
    if not entities:
        return entities

    # Copy each dict before storing it so in-place merging below can never
    # clobber the caller's data.
    merged = [dict(entities[0])]
    for entity in entities[1:]:
        last = merged[-1]
        # 合并相邻的同类型实体 (merge adjacent entities of the same type)
        if (entity["type"] == last["type"] and
                entity["start"] == last["end"]):
            last["text"] += entity["text"]
            last["end"] = entity["end"]
        else:
            merged.append(dict(entity))

    return merged
|
143 |
|
144 |
|
145 |
def ner(text, model_type="bert"):
|
146 |
start_time = time.time()
|
147 |
+
|
148 |
+
# 如果使用的是 ChatGLM 模型,执行 ChatGLM 的NER
|
149 |
if model_type == "chatglm" and use_chatglm:
|
150 |
try:
|
151 |
prompt = f"""请从以下文本中识别所有实体,严格按照JSON列表格式返回,每个实体包含text、type、start、end字段:
|
|
|
155 |
if isinstance(response, tuple):
|
156 |
response = response[0]
|
157 |
|
|
|
158 |
try:
|
159 |
json_str = re.search(r'\[.*\]', response, re.DOTALL).group()
|
160 |
entities = json.loads(json_str)
|
161 |
+
valid_entities = [ent for ent in entities if all(k in ent for k in ("text", "type", "start", "end"))]
|
|
|
|
|
|
|
|
|
162 |
return valid_entities, time.time() - start_time
|
163 |
except Exception as e:
|
164 |
+
print(f"JSON解析失败: {e}")
|
165 |
return [], time.time() - start_time
|
166 |
except Exception as e:
|
167 |
+
print(f"ChatGLM调用失败: {e}")
|
168 |
return [], time.time() - start_time
|
169 |
|
170 |
+
# 使用BERT NER
|
171 |
+
text_chunks = [text[i:i + 510] for i in range(0, len(text), 510)] # 安全分段
|
172 |
raw_results = []
|
|
|
|
|
|
|
173 |
for idx, chunk in enumerate(text_chunks):
|
174 |
chunk_results = bert_ner_pipeline(chunk)
|
|
|
175 |
for r in chunk_results:
|
176 |
+
r["start"] += idx * 510
|
177 |
+
r["end"] += idx * 510
|
178 |
raw_results.extend(chunk_results)
|
179 |
|
180 |
+
entities = [{
|
181 |
+
"text": r['word'].replace(' ', ''),
|
182 |
+
"start": r['start'],
|
183 |
+
"end": r['end'],
|
184 |
+
"type": LABEL_MAPPING.get(r['entity_group'], r['entity_group'])
|
185 |
+
} for r in raw_results]
|
186 |
+
|
|
|
|
|
|
|
|
|
187 |
entities = merge_adjacent_entities(entities)
|
188 |
return entities, time.time() - start_time
|
189 |
|
|
|
337 |
text = content.decode(encoding)
|
338 |
except UnicodeDecodeError:
|
339 |
# 尝试常见中文编码
|
340 |
+
for enc in ['gb18030', 'utf-16', 'big5'] :
|
341 |
try:
|
342 |
text = content.decode(enc)
|
343 |
break
|
|
|
351 |
return f"❌ 文件处理错误: {str(e)}", "", "", ""
|
352 |
|
353 |
|
354 |
+
|
355 |
# ======================== 模型评估与自动标注 ========================
|
356 |
def convert_telegram_json_to_eval_format(path):
|
357 |
with open(path, encoding="utf-8") as f:
|
db_config.ini
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# MySQL connection settings read by get_db_connection() in app.py via
# ConfigParser ('host', 'user', 'password', 'database'; 'charset' has a
# utf8mb4 fallback in code).
# NOTE(review): a plaintext root password is committed here — move the
# credentials to environment variables or a secrets store before deploying.
[mysql]
host = localhost
user = root
password = 123456
database = entity_kg
charset = utf8mb4