# app_chromadb.py
# This file only defines a class and its methods; it is a "module" imported
# and driven by app.py.
import os
import random
import time
import uuid
from pathlib import Path
from typing import Dict, List, Optional

import requests


class MarkdownKnowledgeBase:
    """
    Core class that processes Markdown files, calls the SiliconFlow API to
    obtain embedding vectors, and stores the results in ChromaDB.
    """

    def __init__(self, api_token: str, chroma_collection,
                 base_url: str = "https://api.siliconflow.cn/v1"):
        self.api_token = api_token
        self.base_url = base_url
        self.headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json",
        }
        self.collection = chroma_collection

    def get_embeddings(self, texts: List[str],
                       model: str = "BAAI/bge-m3") -> List[List[float]]:
        """Call the SiliconFlow embeddings API and return one vector per text."""
        url = f"{self.base_url}/embeddings"
        embeddings: List[List[float]] = []
        batch_size = 32
        total_batches = (len(texts) + batch_size - 1) // batch_size

        for batch_idx in range(0, len(texts), batch_size):
            batch = texts[batch_idx:batch_idx + batch_size]
            current_batch = batch_idx // batch_size + 1
            print(f"Processing batch {current_batch}/{total_batches} ({len(batch)} texts)")

            payload = {"model": model, "input": batch, "encoding_format": "float"}
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    response = requests.post(url, json=payload,
                                             headers=self.headers, timeout=60)
                    response.raise_for_status()
                    result = response.json()
                    if 'data' in result:
                        embeddings.extend([item['embedding'] for item in result['data']])
                        break
                    # Unexpected response shape: after the last attempt, pad this
                    # batch with empty lists as failure markers.
                    if attempt == max_retries - 1:
                        embeddings.extend([[] for _ in batch])
                except requests.exceptions.RequestException as e:
                    print(f"  ✗ Request failed (attempt {attempt + 1}/{max_retries}): {e}")
                    if attempt == max_retries - 1:
                        embeddings.extend([[] for _ in batch])
            # Brief pause between batches to stay within API rate limits.
            time.sleep(0.1)
        return embeddings

    def build_knowledge_base(self, folder_path: str, chunk_size: int = 4096,
                             overlap: int = 400, max_files: Optional[int] = None,
                             sample_mode: str = "random"):
        """Scan, chunk, embed, and finally write the data into ChromaDB."""
        print("Scanning files and generating text chunks...")
        md_files = self._scan_files(folder_path)
        if max_files and len(md_files) > max_files:
            md_files = self._sample_files(md_files, max_files, sample_mode)

        all_chunks, all_metadatas = [], []
        for file_path in md_files:
            file_info = self._read_content(file_path)
            # Skip unreadable files and files with almost no content.
            if not file_info or len(file_info['content'].strip()) < 50:
                continue
            chunks = self._chunk_text(file_info['content'], chunk_size, overlap)
            for chunk in chunks:
                if len(chunk.strip()) > 20:
                    all_chunks.append(chunk)
                    all_metadatas.append({'file_name': file_info['file_name'],
                                          'source': file_info['file_path']})

        if not all_chunks:
            print("No valid text chunks to process.")
            return

        print(f"Generated {len(all_chunks)} text chunks in total; fetching embeddings...")
        embeddings = self.get_embeddings(all_chunks)

        # Keep only the chunks whose embedding request succeeded.
        valid_indices = [i for i, emb in enumerate(embeddings) if emb]
        if not valid_indices:
            print("Failed to obtain any embeddings; nothing to add to the knowledge base.")
            return

        valid_embeddings = [embeddings[i] for i in valid_indices]
        valid_chunks = [all_chunks[i] for i in valid_indices]
        valid_metadatas = [all_metadatas[i] for i in valid_indices]
        ids = [str(uuid.uuid4()) for _ in valid_chunks]

        print(f"Embeddings ready; writing {len(ids)} valid entries to ChromaDB in one batch...")
        if ids:  # make sure there is something to add
            self.collection.add(
                embeddings=valid_embeddings,
                documents=valid_chunks,
                metadatas=valid_metadatas,
                ids=ids,
            )
        print("Knowledge base built and stored in ChromaDB successfully!")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """Run a vector search against ChromaDB."""
        print(f"Searching ChromaDB for: '{query}'")
        query_embedding = self.get_embeddings([query])[0]
        if not query_embedding:
            return []

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k,
        )

        formatted_results = []
        if results and results['ids'][0]:
            for i in range(len(results['ids'][0])):
                formatted_results.append({
                    "id": results['ids'][0][i],
                    "content": results['documents'][0][i],
                    "metadata": results['metadatas'][0][i],
                    "distance": results['distances'][0][i],
                })
        return formatted_results

    # --- Private helper methods ---

    def _scan_files(self, folder_path: str) -> List[str]:
        """Recursively collect all .md files under folder_path."""
        md_files = []
        folder = Path(folder_path)
        if not folder.exists():
            return []
        for md_file in folder.rglob("*.md"):
            if md_file.is_file():
                md_files.append(str(md_file.resolve()))
        return md_files

    def _read_content(self, file_path: str) -> Optional[Dict]:
        """Read a file, trying several encodings; return None on failure."""
        try:
            encodings = ['utf-8', 'utf-8-sig', 'gbk', 'cp1252', 'latin1']
            content = None
            for encoding in encodings:
                try:
                    with open(file_path, 'r', encoding=encoding) as file:
                        content = file.read()
                    break
                except UnicodeDecodeError:
                    continue
            if content is None:
                return None
            return {'file_name': os.path.basename(file_path),
                    'content': content,
                    'file_path': file_path}
        except Exception:
            return None

    def _sample_files(self, md_files: List[str], max_files: int,
                      mode: str) -> List[str]:
        """Reduce the file list to max_files using the chosen sampling mode."""
        if mode == "random":
            return random.sample(md_files, min(len(md_files), max_files))
        if mode == "largest":
            return sorted(md_files,
                          key=lambda f: os.path.getsize(f) if os.path.exists(f) else 0,
                          reverse=True)[:max_files]
        if mode == "recent":
            return sorted(md_files,
                          key=lambda f: os.path.getmtime(f) if os.path.exists(f) else 0,
                          reverse=True)[:max_files]
        return md_files[:max_files]

    def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        """Split text into fixed-size chunks with a sliding-window overlap."""
        if len(text) <= chunk_size:
            return [text]
        # Guard against a non-positive step: overlap >= chunk_size would
        # otherwise make this loop run forever.
        step = max(1, chunk_size - overlap)
        chunks = []
        start = 0
        while start < len(text):
            chunks.append(text[start:start + chunk_size])
            start += step
        return chunks
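

# --- Usage sketch (illustrative, not part of the original module) ---
# The header comment says this module is driven by app.py; below is a minimal
# sketch of how such a caller might wire it up. The PersistentClient path, the
# collection name, the SILICONFLOW_API_KEY environment variable, and the
# "./my_notes" folder are assumptions for illustration, not part of this
# module's contract.
if __name__ == "__main__":
    import chromadb

    client = chromadb.PersistentClient(path="./chroma_db")  # assumed storage path
    collection = client.get_or_create_collection(name="markdown_kb")  # assumed name

    kb = MarkdownKnowledgeBase(
        api_token=os.environ.get("SILICONFLOW_API_KEY", ""),  # assumed env var
        chroma_collection=collection,
    )
    # Index a small sample of recent notes, then run a test query.
    kb.build_knowledge_base("./my_notes", max_files=10, sample_mode="recent")
    for hit in kb.search("vector databases", top_k=3):
        print(f"{hit['distance']:.4f}  {hit['metadata']['file_name']}")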