Spaces:
Running
Running
File size: 7,414 Bytes
4b03607 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
# app_chromadb.py
# 这个文件只定义类和方法,它是一个被 app.py 调用的“模块”。
import hashlib
import os
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import requests
class MarkdownKnowledgeBase:
    """Index Markdown files into a ChromaDB collection and query them.

    The class reads ``*.md`` files, splits them into overlapping character
    chunks, requests embedding vectors from the SiliconFlow ``/embeddings``
    endpoint, and stores chunks, vectors and metadata in the collection that
    was passed to the constructor.
    """

    # Number of records written per ``collection.add`` call.
    # NOTE(review): Chroma rejects oversized batches with a backend-dependent
    # limit; 1000 is assumed to be safely below it - confirm for the deployed
    # Chroma version.
    _ADD_BATCH_SIZE = 1000

    def __init__(self, api_token: str, chroma_collection, base_url: str = "https://api.siliconflow.cn/v1"):
        """Store the API credentials and the target collection.

        Args:
            api_token: Bearer token sent in the ``Authorization`` header.
            chroma_collection: Object exposing ``add(...)`` and ``query(...)``
                (used as a ChromaDB collection).
            base_url: Root URL of the embeddings API.
        """
        self.api_token = api_token
        self.base_url = base_url
        self.headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json"
        }
        self.collection = chroma_collection

    def get_embeddings(self, texts: List[str], model: str = "BAAI/bge-m3", *,
                       batch_size: int = 32, max_retries: int = 3) -> List[List[float]]:
        """Fetch one embedding vector per input text from the SiliconFlow API.

        Texts are sent in batches. A batch that still fails after
        ``max_retries`` attempts contributes one empty list per text, so the
        result always has the same length and order as ``texts``.

        Args:
            texts: Strings to embed.
            model: Embedding model name sent in the request payload.
            batch_size: Maximum number of texts per HTTP request.
            max_retries: Attempts per batch before giving up on it.

        Returns:
            A list aligned with ``texts``; failed entries are ``[]``.

        Raises:
            ValueError: If ``batch_size`` or ``max_retries`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        # Tolerate a base URL configured with a trailing slash.
        url = f"{self.base_url.rstrip('/')}/embeddings"
        embeddings: List[List[float]] = []
        total_batches = (len(texts) + batch_size - 1) // batch_size

        for current_batch, offset in enumerate(range(0, len(texts), batch_size), start=1):
            batch = texts[offset:offset + batch_size]
            print(f"处理批次 {current_batch}/{total_batches} ({len(batch)} 个文本)")
            payload = {"model": model, "input": batch, "encoding_format": "float"}

            vectors = self._request_embeddings(url, payload, len(batch), max_retries)
            if vectors is None:
                # Keep positions aligned with ``texts`` even when a batch fails.
                vectors = [[] for _ in batch]
            embeddings.extend(vectors)

            # Short pause between batches; pointless after the final one.
            if current_batch < total_batches:
                time.sleep(0.1)
        return embeddings

    def _request_embeddings(self, url: str, payload: Dict, expected: int,
                            max_retries: int) -> Optional[List[List[float]]]:
        """POST one batch, retrying on failure; return ``None`` if all attempts fail."""
        for attempt in range(1, max_retries + 1):
            try:
                response = requests.post(url, json=payload, headers=self.headers, timeout=60)
                response.raise_for_status()
                return self._parse_embeddings(response.json(), expected)
            except requests.exceptions.RequestException as e:
                print(f" ✗ 请求失败 (尝试 {attempt}/{max_retries}): {e}")
            except (ValueError, KeyError, TypeError) as e:
                # Malformed JSON, missing keys, or a wrong number of vectors.
                print(f" ✗ 响应无效 (尝试 {attempt}/{max_retries}): {e}")
            if attempt < max_retries:
                # Linear back-off so an immediate retry does not hit the same error.
                time.sleep(0.5 * attempt)
        return None

    @staticmethod
    def _parse_embeddings(result: Dict, expected: int) -> List[List[float]]:
        """Extract exactly ``expected`` vectors from a response body, in input order.

        Raises ``KeyError``/``TypeError``/``ValueError`` when the body does not
        have the expected shape, so the caller can treat it as a failed attempt.
        """
        data = result["data"]
        if len(data) != expected:
            raise ValueError(f"expected {expected} embeddings, got {len(data)}")
        # Order by the item's ``index`` field when present; otherwise keep the
        # order in which the items arrived.
        ordered = sorted(enumerate(data), key=lambda pair: pair[1].get("index", pair[0]))
        return [item["embedding"] for _, item in ordered]

    def build_knowledge_base(self, folder_path: str, chunk_size: int = 4096, overlap: int = 400,
                             max_files: Optional[int] = None, sample_mode: str = "random"):
        """Scan, chunk, embed and store all Markdown files under ``folder_path``.

        Args:
            folder_path: Directory searched recursively for ``*.md`` files.
            chunk_size: Maximum characters per chunk.
            overlap: Characters shared between consecutive chunks.
            max_files: If truthy and smaller than the number of files found,
                only that many files are processed.
            sample_mode: How files are picked when ``max_files`` applies:
                ``"random"``, ``"largest"``, ``"recent"``; anything else
                takes the first ``max_files`` files.

        Returns:
            None. Progress and outcome are reported via ``print``.
        """
        print("扫描文件并生成文本块...")
        md_files = self._scan_files(folder_path)
        if max_files and len(md_files) > max_files:
            md_files = self._sample_files(md_files, max_files, sample_mode)

        all_chunks: List[str] = []
        all_metadatas: List[Dict] = []
        for file_path in md_files:
            file_info = self._read_content(file_path)
            # Skip unreadable files and files with fewer than 50 non-blank-edge characters.
            if not file_info or len(file_info['content'].strip()) < 50:
                continue
            for chunk in self._chunk_text(file_info['content'], chunk_size, overlap):
                # Chunks of 20 or fewer stripped characters are not stored.
                if len(chunk.strip()) > 20:
                    all_chunks.append(chunk)
                    all_metadatas.append({'file_name': file_info['file_name'], 'source': file_info['file_path']})

        if not all_chunks:
            print("没有有效的文本块可供处理。")
            return

        print(f"总共生成 {len(all_chunks)} 个文本块,开始获取向量...")
        embeddings = self.get_embeddings(all_chunks)

        # Drop every chunk whose embedding request failed (empty vector).
        valid = [
            (chunk, metadata, vector)
            for chunk, metadata, vector in zip(all_chunks, all_metadatas, embeddings)
            if vector
        ]
        if not valid:
            print("未能成功获取任何向量,无法添加到知识库。")
            return
        skipped = len(all_chunks) - len(valid)
        if skipped:
            print(f"警告: {skipped} 个文本块未能获取向量,已跳过。")

        valid_chunks = [chunk for chunk, _, _ in valid]
        valid_metadatas = [metadata for _, metadata, _ in valid]
        valid_embeddings = [vector for _, _, vector in valid]
        ids = [str(uuid.uuid4()) for _ in valid_chunks]

        print(f"获取向量完成,正在将 {len(ids)} 个有效条目批量写入 ChromaDB...")
        # Write in bounded slices so a large corpus is not sent in one call.
        for begin in range(0, len(ids), self._ADD_BATCH_SIZE):
            stop = begin + self._ADD_BATCH_SIZE
            self.collection.add(
                embeddings=valid_embeddings[begin:stop],
                documents=valid_chunks[begin:stop],
                metadatas=valid_metadatas[begin:stop],
                ids=ids[begin:stop]
            )
        print("知识库构建并存入 ChromaDB 成功!")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Run a vector search in the collection for ``query``.

        Args:
            query: Natural-language query text.
            top_k: Number of results requested from the collection.

        Returns:
            A list of dicts with keys ``id``, ``content``, ``metadata`` and
            ``distance``; empty when the query is blank, the embedding could
            not be obtained, or nothing matched.
        """
        print(f"在 ChromaDB 中搜索: '{query}'")
        # A blank query cannot be embedded meaningfully; skip the API call.
        if not query or not query.strip():
            return []

        vectors = self.get_embeddings([query])
        query_embedding = vectors[0] if vectors else []
        if not query_embedding:
            return []

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )
        if not results or not results['ids'][0]:
            return []

        # Each result field is a list of lists; index 0 is the single query sent.
        return [
            {"id": item_id, "content": document, "metadata": metadata, "distance": distance}
            for item_id, document, metadata, distance in zip(
                results['ids'][0],
                results['documents'][0],
                results['metadatas'][0],
                results['distances'][0],
            )
        ]

    # --- Private helpers ---

    def _scan_files(self, folder_path: str) -> List[str]:
        """Return the resolved paths of all ``*.md`` files under ``folder_path``.

        The result is sorted so that downstream "first N" selection is
        deterministic. A missing path or a non-directory yields ``[]``.
        """
        folder = Path(folder_path)
        if not folder.is_dir():
            return []
        return sorted(str(md_file.resolve()) for md_file in folder.rglob("*.md") if md_file.is_file())

    def _read_content(self, file_path: str) -> Optional[Dict]:
        """Read a text file, trying several encodings in turn.

        Returns:
            A dict with ``file_name``, ``content`` and ``file_path``, or
            ``None`` when the file cannot be opened or decoded.
        """
        # 'utf-8-sig' decodes plain UTF-8 as well and strips a leading BOM,
        # so it replaces a separate 'utf-8' attempt.
        encodings = ['utf-8-sig', 'gbk', 'cp1252', 'latin1']
        for encoding in encodings:
            try:
                with open(file_path, 'r', encoding=encoding) as file:
                    content = file.read()
            except UnicodeDecodeError:
                continue
            except (OSError, ValueError) as e:
                # Best effort: report the problem and let the caller skip the file.
                print(f" ✗ 无法读取文件 {file_path}: {e}")
                return None
            return {'file_name': os.path.basename(file_path), 'content': content, 'file_path': file_path}
        return None

    def _sample_files(self, md_files: List[str], max_files: int, mode: str) -> List[str]:
        """Pick at most ``max_files`` entries from ``md_files``.

        ``"random"`` samples without replacement, ``"largest"`` prefers the
        biggest files, ``"recent"`` the most recently modified; any other
        mode returns the first ``max_files`` entries.
        """
        if mode == "random":
            import random
            return random.sample(md_files, min(len(md_files), max_files))

        stat_getters = {"largest": os.path.getsize, "recent": os.path.getmtime}
        getter = stat_getters.get(mode)
        if getter is None:
            return md_files[:max_files]

        def sort_key(path: str):
            # A file that vanished or is unreadable sorts as 0 instead of raising.
            try:
                return getter(path)
            except OSError:
                return 0

        return sorted(md_files, key=sort_key, reverse=True)[:max_files]

    def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """Split ``text`` into chunks of ``chunk_size`` characters.

        Consecutive chunks share ``overlap`` characters. Text no longer than
        ``chunk_size`` is returned as a single chunk.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``overlap`` is
                not in ``[0, chunk_size)``; such values would otherwise never
                advance through the text or would skip parts of it.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        if not 0 <= overlap < chunk_size:
            raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")
        if len(text) <= chunk_size:
            return [text]

        step = chunk_size - overlap
        chunks = []
        start = 0
        while True:
            end = start + chunk_size
            chunks.append(text[start:end])
            # Stop once the text is fully covered; continuing would only emit
            # a tail that is already contained in the chunk just added.
            if end >= len(text):
                break
            start += step
        return chunks
|