# app_chromadb.py
# This file only defines classes and methods; it is a "module" imported and called by app.py.

import os
import random
import time
import uuid
from pathlib import Path
from typing import Dict, List, Optional

import requests

class MarkdownKnowledgeBase:
    """
    负责处理 Markdown 文件、与 SiliconFlow API 交互以获取向量,
    并将最终数据存入 ChromaDB 的核心类。
    """
    def __init__(self, api_token: str, chroma_collection, base_url: str = "https://api.siliconflow.cn/v1"):
        self.api_token = api_token
        self.base_url = base_url
        self.headers = {
            "Authorization": f"Bearer {api_token}",
            "Content-Type": "application/json"
        }
        self.collection = chroma_collection
        
    def get_embeddings(self, texts: List[str], model: str = "BAAI/bge-m3") -> List[List[float]]:
        """
        调用 SiliconFlow API 获取文本的嵌入向量。
        """
        url = f"{self.base_url}/embeddings"
        embeddings = []
        batch_size = 32
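        # NOTE: 32 texts per request is an assumed safe batch size for the embeddings endpoint;
        # adjust it if the API enforces different limits on batch count or input length.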
        total_batches = (len(texts) + batch_size - 1) // batch_size
        
        for batch_idx in range(0, len(texts), batch_size):
            batch = texts[batch_idx:batch_idx + batch_size]
            current_batch = batch_idx // batch_size + 1
            print(f"处理批次 {current_batch}/{total_batches} ({len(batch)} 个文本)")
            payload = {"model": model, "input": batch, "encoding_format": "float"}
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    response = requests.post(url, json=payload, headers=self.headers, timeout=60)
                    response.raise_for_status()
                    result = response.json()
                    if 'data' in result:
                        embeddings.extend([item['embedding'] for item in result['data']])
                        break
                    else:
                        print(f"  ✗ Unexpected response format (attempt {attempt + 1}/{max_retries})")
                        if attempt == max_retries - 1:
                            # Pad with empty lists so results stay aligned with the input texts
                            embeddings.extend([[] for _ in batch])
                except requests.exceptions.RequestException as e:
                    print(f"  ✗ Request failed (attempt {attempt + 1}/{max_retries}): {e}")
                    if attempt == max_retries - 1:
                        # Pad with empty lists so results stay aligned with the input texts
                        embeddings.extend([[] for _ in batch])
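            # Brief pause between batches to stay within the API's rate limits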
            time.sleep(0.1)
        return embeddings

    def build_knowledge_base(self, folder_path: str, chunk_size: int = 4096, overlap: int = 400,
                             max_files: Optional[int] = None, sample_mode: str = "random"):
        """
        Scan Markdown files, chunk them, embed the chunks, and store everything in ChromaDB.
        """
        print("扫描文件并生成文本块...")
        md_files = self._scan_files(folder_path)
        if max_files and len(md_files) > max_files:
            md_files = self._sample_files(md_files, max_files, sample_mode)
        
        all_chunks, all_metadatas = [], []
        for file_path in md_files:
            file_info = self._read_content(file_path)
            if not file_info or len(file_info['content'].strip()) < 50:
                continue
            chunks = self._chunk_text(file_info['content'], chunk_size, overlap)
            for chunk in chunks:
                if len(chunk.strip()) > 20:
                    all_chunks.append(chunk)
                    all_metadatas.append({'file_name': file_info['file_name'], 'source': file_info['file_path']})
        
        if not all_chunks:
            print("没有有效的文本块可供处理。")
            return
            
        print(f"总共生成 {len(all_chunks)} 个文本块,开始获取向量...")
        embeddings = self.get_embeddings(all_chunks)
        
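        # Keep only the chunks whose embedding request succeeded (failed batches were padded with empty lists)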
        valid_indices = [i for i, emb in enumerate(embeddings) if emb]
        if not valid_indices:
             print("未能成功获取任何向量,无法添加到知识库。")
             return

        valid_embeddings = [embeddings[i] for i in valid_indices]
        valid_chunks = [all_chunks[i] for i in valid_indices]
        valid_metadatas = [all_metadatas[i] for i in valid_indices]
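        # Random UUIDs are used as document ids, so re-running the build appends new entries
        # rather than upserting; deduplicate upstream if that matters for your use case.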
        ids = [str(uuid.uuid4()) for _ in valid_chunks]
        
        print(f"获取向量完成,正在将 {len(ids)} 个有效条目批量写入 ChromaDB...")
        
        if ids:  # make sure there is something to add
            self.collection.add(
                embeddings=valid_embeddings,
                documents=valid_chunks,
                metadatas=valid_metadatas,
                ids=ids
            )
        
        print("知识库构建并存入 ChromaDB 成功!")

    def search(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        在 ChromaDB 中执行向量搜索。
        """
        print(f"在 ChromaDB 中搜索: '{query}'")
        query_embedding = self.get_embeddings([query])[0]
        if not query_embedding:
            return []

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=top_k
        )
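        # ChromaDB returns one nested list per query embedding; only one query was sent,
        # hence the [0] indexing below.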
        
        formatted_results = []
        if results and results['ids'][0]:
            for i in range(len(results['ids'][0])):
                formatted_results.append({
                    "id": results['ids'][0][i],
                    "content": results['documents'][0][i],
                    "metadata": results['metadatas'][0][i],
                    "distance": results['distances'][0][i] 
                })
        return formatted_results

    # --- Private helper methods ---
    def _scan_files(self, folder_path: str) -> List[str]:
        md_files = []
        folder = Path(folder_path)
        if not folder.exists(): return []
        for md_file in folder.rglob("*.md"):
            if md_file.is_file(): md_files.append(str(md_file.resolve()))
        return md_files

    def _read_content(self, file_path: str) -> Optional[Dict]:
        try:
            encodings = ['utf-8', 'utf-8-sig', 'gbk', 'cp1252', 'latin1']
            content = None
            for encoding in encodings:
                try:
                    with open(file_path, 'r', encoding=encoding) as file:
                        content = file.read()
                        break
                except UnicodeDecodeError: continue
            if content is None: return None
            return {'file_name': os.path.basename(file_path), 'content': content, 'file_path': file_path}
        except Exception:
            return None
        
    def _sample_files(self, md_files: List[str], max_files: int, mode: str) -> List[str]:
        if mode == "random":
            return random.sample(md_files, min(len(md_files), max_files))
        elif mode == "largest":
            return sorted(md_files, key=lambda f: os.path.getsize(f) if os.path.exists(f) else 0, reverse=True)[:max_files]
        elif mode == "recent":
            return sorted(md_files, key=lambda f: os.path.getmtime(f) if os.path.exists(f) else 0, reverse=True)[:max_files]
        return md_files[:max_files]

    def _chunk_text(self, text: str, chunk_size: int, overlap: int) -> List[str]:
        if len(text) <= chunk_size: return [text]
        chunks = []
        start = 0
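        # Sliding window: each chunk overlaps the previous one by `overlap` characters
        # so context is preserved across chunk boundaries.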
        step = max(1, chunk_size - overlap)  # guard against an infinite loop when overlap >= chunk_size
        while start < len(text):
            chunks.append(text[start:start + chunk_size])
            start += step
        return chunks
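

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how app.py might wire this class up, assuming the
# `chromadb` package is installed and the SiliconFlow key is available in the
# SILICONFLOW_API_KEY environment variable. The collection name, storage path,
# and docs folder below are placeholders.
if __name__ == "__main__":
    import chromadb

    client = chromadb.PersistentClient(path="./chroma_db")
    collection = client.get_or_create_collection(name="markdown_kb")

    kb = MarkdownKnowledgeBase(
        api_token=os.environ["SILICONFLOW_API_KEY"],
        chroma_collection=collection,
    )

    # Build the knowledge base from a folder of Markdown files, then run a query.
    kb.build_knowledge_base("./docs", chunk_size=4096, overlap=400, max_files=100)
    for hit in kb.search("How do I configure the embedding model?", top_k=3):
        print(hit["distance"], hit["metadata"]["file_name"])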