gewei20 committed · Commit 4b03607 · verified · 1 Parent(s): 671d8e3

Create app_chromadb.py

Files changed (1)
  1. app_chromadb.py +176 -0
app_chromadb.py ADDED
@@ -0,0 +1,176 @@
+ # app_chromadb.py
+ # This file only defines a class and its methods; it is a "module" imported and called by app.py.
+
+ import os
+ import requests
+ import hashlib
+ from pathlib import Path
+ from typing import List, Dict, Optional
+ import time
+ from datetime import datetime
+ import uuid
+
+ class MarkdownKnowledgeBase:
+     """
+     Core class that processes Markdown files, calls the SiliconFlow API
+     to obtain embeddings, and stores the resulting data in ChromaDB.
+     """
+     def __init__(self, api_token: str, chroma_collection, base_url: str = "https://api.siliconflow.cn/v1"):
+         self.api_token = api_token
+         self.base_url = base_url
+         self.headers = {
+             "Authorization": f"Bearer {api_token}",
+             "Content-Type": "application/json"
+         }
+         self.collection = chroma_collection
+
+     def get_embeddings(self, texts: List[str], model: str = "BAAI/bge-m3") -> List[List[float]]:
+         """
+         Call the SiliconFlow API to obtain embedding vectors for the given texts.
+         """
+         url = f"{self.base_url}/embeddings"
+         embeddings = []
+         batch_size = 32
+         total_batches = (len(texts) + batch_size - 1) // batch_size
+
+         for batch_idx in range(0, len(texts), batch_size):
+             batch = texts[batch_idx:batch_idx + batch_size]
+             current_batch = batch_idx // batch_size + 1
+             print(f"Processing batch {current_batch}/{total_batches} ({len(batch)} texts)")
+             payload = {"model": model, "input": batch, "encoding_format": "float"}
+             max_retries = 3
+             for attempt in range(max_retries):
+                 try:
+                     response = requests.post(url, json=payload, headers=self.headers, timeout=60)
+                     response.raise_for_status()
+                     result = response.json()
+                     if 'data' in result:
+                         embeddings.extend([item['embedding'] for item in result['data']])
+                         break
+                     else:
+                         if attempt == max_retries - 1: embeddings.extend([[] for _ in batch])
+                 except requests.exceptions.RequestException as e:
+                     print(f" ✗ Request failed (attempt {attempt + 1}/{max_retries}): {e}")
+                     if attempt == max_retries - 1: embeddings.extend([[] for _ in batch])
+             time.sleep(0.1)  # brief pause between batches to stay gentle on the API
+         return embeddings
+
+     def build_knowledge_base(self, folder_path: str, chunk_size: int = 4096, overlap: int = 400,
+                              max_files: Optional[int] = None, sample_mode: str = "random"):
+         """
+         Scan the folder, chunk the text, embed the chunks, and store the results in ChromaDB.
+         """
+         print("Scanning files and generating text chunks...")
+         md_files = self._scan_files(folder_path)
+         if max_files and len(md_files) > max_files:
+             md_files = self._sample_files(md_files, max_files, sample_mode)
+
+         all_chunks, all_metadatas = [], []
+         for file_path in md_files:
+             file_info = self._read_content(file_path)
+             if not file_info or len(file_info['content'].strip()) < 50:
+                 continue
+             chunks = self._chunk_text(file_info['content'], chunk_size, overlap)
+             for chunk in chunks:
+                 if len(chunk.strip()) > 20:
+                     all_chunks.append(chunk)
+                     all_metadatas.append({'file_name': file_info['file_name'], 'source': file_info['file_path']})
+
+         if not all_chunks:
+             print("No valid text chunks to process.")
+             return
+
+         print(f"Generated {len(all_chunks)} text chunks in total; fetching embeddings...")
+         embeddings = self.get_embeddings(all_chunks)
+
+         valid_indices = [i for i, emb in enumerate(embeddings) if emb]
+         if not valid_indices:
+             print("No embeddings could be obtained; nothing will be added to the knowledge base.")
+             return
+
+         valid_embeddings = [embeddings[i] for i in valid_indices]
+         valid_chunks = [all_chunks[i] for i in valid_indices]
+         valid_metadatas = [all_metadatas[i] for i in valid_indices]
+         ids = [str(uuid.uuid4()) for _ in valid_chunks]
+
+         print(f"Embeddings fetched; writing {len(ids)} valid entries to ChromaDB in one batch...")
+
+         if ids:  # make sure there is something to add
+             self.collection.add(
+                 embeddings=valid_embeddings,
+                 documents=valid_chunks,
+                 metadatas=valid_metadatas,
+                 ids=ids
+             )
+
+         print("Knowledge base built and stored in ChromaDB successfully!")
+
+     def search(self, query: str, top_k: int = 5) -> List[Dict]:
+         """
+         Run a vector similarity search against ChromaDB.
+         """
+         print(f"Searching ChromaDB for: '{query}'")
+         query_embedding = self.get_embeddings([query])[0]
+         if not query_embedding:
+             return []
+
+         results = self.collection.query(
+             query_embeddings=[query_embedding],
+             n_results=top_k
+         )
+
+         formatted_results = []
+         if results and results['ids'][0]:
+             for i in range(len(results['ids'][0])):
+                 formatted_results.append({
+                     "id": results['ids'][0][i],
+                     "content": results['documents'][0][i],
+                     "metadata": results['metadatas'][0][i],
+                     "distance": results['distances'][0][i]
+                 })
+         return formatted_results
+
+     # --- Private helper methods ---
+     def _scan_files(self, folder_path: str) -> List[str]:
+         md_files = []
+         folder = Path(folder_path)
+         if not folder.exists(): return []
+         for md_file in folder.rglob("*.md"):
+             if md_file.is_file(): md_files.append(str(md_file.resolve()))
+         return md_files
+
+     def _read_content(self, file_path: str) -> Optional[Dict]:
+         try:
+             encodings = ['utf-8', 'utf-8-sig', 'gbk', 'cp1252', 'latin1']
+             content = None
+             for encoding in encodings:
+                 try:
+                     with open(file_path, 'r', encoding=encoding) as file:
+                         content = file.read()
+                     break
+                 except UnicodeDecodeError: continue
+             if content is None: return None
+             return {'file_name': os.path.basename(file_path), 'content': content, 'file_path': file_path}
+         except Exception:
+             return None
+
+     def _sample_files(self, md_files: List[str], max_files: int, mode: str) -> List[str]:
+         if mode == "random":
+             import random
+             return random.sample(md_files, min(len(md_files), max_files))
+         elif mode == "largest":
+             return sorted(md_files, key=lambda f: os.path.getsize(f) if os.path.exists(f) else 0, reverse=True)[:max_files]
+         elif mode == "recent":
+             return sorted(md_files, key=lambda f: os.path.getmtime(f) if os.path.exists(f) else 0, reverse=True)[:max_files]
+         return md_files[:max_files]
+
+     def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
+         if len(text) <= chunk_size: return [text]
+         chunks = []
+         start = 0
+         while start < len(text):
+             end = start + chunk_size
+             chunk = text[start:end]
+             chunks.append(chunk)
+             start += chunk_size - overlap
+         return chunks
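
For context, here is a minimal sketch of how app.py might wire this module up, assuming ChromaDB's PersistentClient API. The storage path, collection name, docs folder, query string, and the SILICONFLOW_API_KEY environment variable are illustrative assumptions, not part of this commit.

# app.py (hypothetical caller) -- illustrative wiring only
import os
import chromadb

from app_chromadb import MarkdownKnowledgeBase

# Persist the vector store on disk and create (or reuse) a collection.
client = chromadb.PersistentClient(path="./chroma_db")            # assumed path
collection = client.get_or_create_collection(name="markdown_kb")  # assumed name

kb = MarkdownKnowledgeBase(
    api_token=os.environ["SILICONFLOW_API_KEY"],  # assumed env var name
    chroma_collection=collection,
)

# One-time ingestion: scan ./docs, chunk, embed via SiliconFlow, write to ChromaDB.
kb.build_knowledge_base("./docs", max_files=100, sample_mode="recent")

# Query time: embed the question and return the top 3 nearest chunks.
for hit in kb.search("How do I configure the embedding model?", top_k=3):
    print(hit["metadata"]["file_name"], round(hit["distance"], 3))

Because the ChromaDB collection is injected through the constructor rather than created inside the class, the calling app stays in control of where the vector store lives and how it is persisted.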