Gradio MCP
- app.py +676 -214
- core/__init__.py +1 -0
- core/chunker.py +302 -0
- core/document_parser.py +199 -0
- core/models.py +102 -0
- core/text_preprocessor.py +186 -0
- mcp_server.py +165 -70
- mcp_tools.py +0 -592
- mcp_tools/__init__.py +1 -0
- mcp_tools/generative_tool.py +342 -0
- mcp_tools/ingestion_tool.py +330 -0
- mcp_tools/search_tool.py +423 -0
- mcp_tools/utils.py +373 -0
- services/__init__.py +1 -0
- services/document_store_service.py +349 -0
- services/embedding_service.py +204 -0
- services/llm_service.py +285 -0
- services/ocr_service.py +324 -0
- services/vector_store_service.py +285 -0
app.py
CHANGED
@@ -1,254 +1,716 @@
Removed (the old 254-line standalone app; only fragments survive in this capture): imports (gradio, asyncio, pathlib.Path, tempfile, json, typing, logging); a guarded `from mcp_server import mcp` that set an `MCP_AVAILABLE` flag and printed "⚠️ MCP server not available, running in standalone mode" on ImportError; upload, web-content (`await mcp_tools.process_web_content(url)`), and search handlers built on the old monolithic `mcp_tools` module, including a `global current_results` list of result cards; a minimal search UI (a "Search Query" textbox, a "Search" button, a "Status" textbox, and a Markdown results pane); and a `__main__` guard that either ran `asyncio.run(mcp.run())` as an MCP server or launched the demo via `demo.launch(server_name="0.0.0.0", share=False, show_error=True)`.
Added (the rewritten app.py, 716 lines):

```python
import gradio as gr
import os
import asyncio
import json
import logging
import tempfile
import uuid
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
import nest_asyncio

# Apply nest_asyncio to handle nested event loops in Gradio
nest_asyncio.apply()

# Import our custom modules
from mcp_tools.ingestion_tool import IngestionTool
from mcp_tools.search_tool import SearchTool
from mcp_tools.generative_tool import GenerativeTool
from services.vector_store_service import VectorStoreService
from services.document_store_service import DocumentStoreService
from services.embedding_service import EmbeddingService
from services.llm_service import LLMService
from services.ocr_service import OCRService
from core.models import SearchResult, Document
import config

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ContentOrganizerMCPServer:
    def __init__(self):
        # Initialize services
        logger.info("Initializing Content Organizer MCP Server...")

        self.vector_store = VectorStoreService()
        self.document_store = DocumentStoreService()
        self.embedding_service = EmbeddingService()
        self.llm_service = LLMService()
        self.ocr_service = OCRService()

        # Initialize tools
        self.ingestion_tool = IngestionTool(
            vector_store=self.vector_store,
            document_store=self.document_store,
            embedding_service=self.embedding_service,
            ocr_service=self.ocr_service
        )
        self.search_tool = SearchTool(
            vector_store=self.vector_store,
            embedding_service=self.embedding_service,
            document_store=self.document_store
        )
        self.generative_tool = GenerativeTool(
            llm_service=self.llm_service,
            search_tool=self.search_tool
        )

        # Track processing status
        self.processing_status = {}

        # Document cache for quick access
        self.document_cache = {}

        logger.info("Content Organizer MCP Server initialized successfully!")

    def run_async(self, coro):
        """Helper to run async functions in Gradio"""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if loop.is_running():
            # If a loop is already running, run the coroutine on a worker thread
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor() as executor:
                future = executor.submit(asyncio.run, coro)
                return future.result()
        else:
            return loop.run_until_complete(coro)

    async def ingest_document_async(self, file_path: str, file_type: str) -> Dict[str, Any]:
        """MCP Tool: Ingest and process a document"""
        try:
            task_id = str(uuid.uuid4())
            self.processing_status[task_id] = {"status": "processing", "progress": 0}

            result = await self.ingestion_tool.process_document(file_path, file_type, task_id)

            if result.get("success"):
                self.processing_status[task_id] = {"status": "completed", "progress": 100}
                # Update document cache
                doc_id = result.get("document_id")
                if doc_id:
                    doc = await self.document_store.get_document(doc_id)
                    if doc:
                        self.document_cache[doc_id] = doc

                return result
            else:
                self.processing_status[task_id] = {"status": "failed", "error": result.get("error")}
                return result

        except Exception as e:
            logger.error(f"Document ingestion failed: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "message": "Failed to process document"
            }

    async def get_document_content_async(self, document_id: str) -> Optional[str]:
        """Get document content by ID"""
        try:
            # Check cache first
            if document_id in self.document_cache:
                return self.document_cache[document_id].content

            # Get from store
            doc = await self.document_store.get_document(document_id)
            if doc:
                self.document_cache[document_id] = doc
                return doc.content

            return None
        except Exception as e:
            logger.error(f"Error getting document content: {str(e)}")
            return None

    async def semantic_search_async(self, query: str, top_k: int = 5, filters: Optional[Dict] = None) -> Dict[str, Any]:
        """MCP Tool: Perform semantic search"""
        try:
            results = await self.search_tool.search(query, top_k, filters)
            return {
                "success": True,
                "query": query,
                "results": [result.to_dict() for result in results],
                "total_results": len(results)
            }
        except Exception as e:
            logger.error(f"Semantic search failed: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "query": query,
                "results": []
            }

    async def summarize_content_async(self, content: str = None, document_id: str = None, style: str = "concise") -> Dict[str, Any]:
        """MCP Tool: Summarize content or document"""
        try:
            # If document_id provided, get content from document
            if document_id and document_id != "none":
                content = await self.get_document_content_async(document_id)
                if not content:
                    return {"success": False, "error": f"Document {document_id} not found"}

            if not content or not content.strip():
                return {"success": False, "error": "No content provided for summarization"}

            # Record the true length before truncation so the stats are accurate
            original_length = len(content)

            # Truncate content if too long (for API limits)
            max_content_length = 4000
            if len(content) > max_content_length:
                content = content[:max_content_length] + "..."

            summary = await self.generative_tool.summarize(content, style)
            return {
                "success": True,
                "summary": summary,
                "original_length": original_length,
                "summary_length": len(summary),
                "style": style,
                "document_id": document_id
            }
        except Exception as e:
            logger.error(f"Summarization failed: {str(e)}")
            return {
                "success": False,
                "error": str(e)
            }

    async def generate_tags_async(self, content: str = None, document_id: str = None, max_tags: int = 5) -> Dict[str, Any]:
        """MCP Tool: Generate tags for content"""
        try:
            # If document_id provided, get content from document
            if document_id and document_id != "none":
                content = await self.get_document_content_async(document_id)
                if not content:
                    return {"success": False, "error": f"Document {document_id} not found"}

            if not content or not content.strip():
                return {"success": False, "error": "No content provided for tag generation"}

            tags = await self.generative_tool.generate_tags(content, max_tags)

            # Update document tags if document_id provided
            if document_id and document_id != "none" and tags:
                await self.document_store.update_document_metadata(document_id, {"tags": tags})

            return {
                "success": True,
                "tags": tags,
                "content_length": len(content),
                "document_id": document_id
            }
        except Exception as e:
            logger.error(f"Tag generation failed: {str(e)}")
            return {
                "success": False,
                "error": str(e)
            }

    async def answer_question_async(self, question: str, context_filter: Optional[Dict] = None) -> Dict[str, Any]:
        """MCP Tool: Answer questions using RAG"""
        try:
            # Search for relevant context
            search_results = await self.search_tool.search(question, top_k=5, filters=context_filter)

            if not search_results:
                return {
                    "success": False,
                    "error": "No relevant context found in your documents. Please make sure you have uploaded relevant documents.",
                    "question": question
                }

            # Generate answer using context
            answer = await self.generative_tool.answer_question(question, search_results)

            return {
                "success": True,
                "question": question,
                "answer": answer,
                "sources": [result.to_dict() for result in search_results],
                "confidence": "high" if len(search_results) >= 3 else "medium"
            }
        except Exception as e:
            logger.error(f"Question answering failed: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "question": question
            }

    def list_documents_sync(self, limit: int = 100, offset: int = 0) -> Dict[str, Any]:
        """List stored documents"""
        try:
            documents = self.run_async(self.document_store.list_documents(limit, offset))
            return {
                "success": True,
                "documents": [doc.to_dict() for doc in documents],
                "total": len(documents)
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

# Initialize the MCP server
mcp_server = ContentOrganizerMCPServer()

# Helper functions
def get_document_list():
    """Get list of documents for display"""
    try:
        result = mcp_server.list_documents_sync(limit=100)
        if result["success"]:
            if result["documents"]:
                doc_list = "📚 Documents in Library:\n\n"
                for i, doc in enumerate(result["documents"], 1):
                    doc_list += f"{i}. {doc['filename']} (ID: {doc['id'][:8]}...)\n"
                    doc_list += f"   Type: {doc['doc_type']}, Size: {doc['file_size']} bytes\n"
                    if doc.get('tags'):
                        doc_list += f"   Tags: {', '.join(doc['tags'])}\n"
                    doc_list += f"   Created: {doc['created_at'][:10]}\n\n"
                return doc_list
            else:
                return "No documents in library yet. Upload some documents to get started!"
        else:
            return f"Error loading documents: {result['error']}"
    except Exception as e:
        return f"Error: {str(e)}"

def get_document_choices():
    """Get document choices for dropdown"""
    try:
        result = mcp_server.list_documents_sync(limit=100)
        if result["success"] and result["documents"]:
            choices = []
            for doc in result["documents"]:
                # Create label with filename and shortened ID
                choice_label = f"{doc['filename']} ({doc['id'][:8]}...)"
                # Use full document ID as the value
                choices.append((choice_label, doc['id']))

            logger.info(f"Generated {len(choices)} document choices")
            return choices
        return []
    except Exception as e:
        logger.error(f"Error getting document choices: {str(e)}")
        return []

# Gradio Interface Functions
def upload_and_process_file(file):
    """Gradio interface for file upload"""
    if file is None:
        # One return value per registered output (six outputs are wired to this handler)
        choices = get_document_choices()
        return (
            "No file uploaded",
            "",
            get_document_list(),
            gr.update(choices=choices),
            gr.update(choices=choices),
            gr.update(choices=choices)
        )

    try:
        # Get file path
        file_path = file.name if hasattr(file, 'name') else str(file)
        file_type = Path(file_path).suffix.lower()

        logger.info(f"Processing file: {file_path}")

        # Process document
        result = mcp_server.run_async(mcp_server.ingest_document_async(file_path, file_type))

        if result["success"]:
            # Get updated document list and choices
            doc_list = get_document_list()
            doc_choices = get_document_choices()

            return (
                f"✅ Success: {result['message']}\nDocument ID: {result['document_id']}\nChunks created: {result['chunks_created']}",
                result["document_id"],
                doc_list,
                gr.update(choices=doc_choices),
                gr.update(choices=doc_choices),
                gr.update(choices=doc_choices)
            )
        else:
            return (
                f"❌ Error: {result.get('error', 'Unknown error')}",
                "",
                get_document_list(),
                gr.update(choices=get_document_choices()),
                gr.update(choices=get_document_choices()),
                gr.update(choices=get_document_choices())
            )
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        return (
            f"❌ Error: {str(e)}",
            "",
            get_document_list(),
            gr.update(choices=get_document_choices()),
            gr.update(choices=get_document_choices()),
            gr.update(choices=get_document_choices())
        )

def perform_search(query, top_k):
    """Gradio interface for search"""
    if not query.strip():
        return "Please enter a search query"

    try:
        result = mcp_server.run_async(mcp_server.semantic_search_async(query, int(top_k)))

        if result["success"]:
            if result["results"]:
                output = f"🔍 Found {result['total_results']} results for: '{query}'\n\n"
                for i, res in enumerate(result["results"], 1):
                    output += f"Result {i}:\n"
                    output += f"📊 Relevance Score: {res['score']:.3f}\n"
                    output += f"📄 Content: {res['content'][:300]}...\n"
                    if 'document_filename' in res.get('metadata', {}):
                        output += f"📁 Source: {res['metadata']['document_filename']}\n"
                    output += f"🔗 Document ID: {res.get('document_id', 'Unknown')}\n"
                    output += "-" * 80 + "\n\n"
                return output
            else:
                return f"No results found for: '{query}'\n\nMake sure you have uploaded relevant documents first."
        else:
            return f"❌ Search failed: {result['error']}"
    except Exception as e:
        logger.error(f"Search error: {str(e)}")
        return f"❌ Error: {str(e)}"

def summarize_document(doc_choice, custom_text, style):
    """Gradio interface for summarization"""
    try:
        # Debug logging
        logger.info(f"Summarize called with doc_choice: {doc_choice}, type: {type(doc_choice)}")

        # Get document ID from dropdown choice
        document_id = None
        if doc_choice and doc_choice != "none" and doc_choice != "":
            # The dropdown returns the value part of the (label, value) tuple
            document_id = doc_choice
            logger.info(f"Using document ID: {document_id}")

        # Use custom text if provided, otherwise use document
        if custom_text and custom_text.strip():
            logger.info("Using custom text for summarization")
            result = mcp_server.run_async(mcp_server.summarize_content_async(content=custom_text, style=style))
        elif document_id:
            logger.info(f"Summarizing document: {document_id}")
            result = mcp_server.run_async(mcp_server.summarize_content_async(document_id=document_id, style=style))
        else:
            return "Please select a document from the dropdown or enter text to summarize"

        if result["success"]:
            output = f"📝 Summary ({style} style):\n\n{result['summary']}\n\n"
            output += f"📊 Statistics:\n"
            output += f"- Original length: {result['original_length']} characters\n"
            output += f"- Summary length: {result['summary_length']} characters\n"
            output += f"- Compression ratio: {(1 - result['summary_length']/result['original_length'])*100:.1f}%\n"
            if result.get('document_id'):
                output += f"- Document ID: {result['document_id']}\n"
            return output
        else:
            return f"❌ Summarization failed: {result['error']}"
    except Exception as e:
        logger.error(f"Summarization error: {str(e)}")
        return f"❌ Error: {str(e)}"

def generate_tags_for_document(doc_choice, custom_text, max_tags):
    """Gradio interface for tag generation"""
    try:
        # Debug logging
        logger.info(f"Generate tags called with doc_choice: {doc_choice}, type: {type(doc_choice)}")

        # Get document ID from dropdown choice
        document_id = None
        if doc_choice and doc_choice != "none" and doc_choice != "":
            # The dropdown returns the value part of the (label, value) tuple
            document_id = doc_choice
            logger.info(f"Using document ID: {document_id}")

        # Use custom text if provided, otherwise use document
        if custom_text and custom_text.strip():
            logger.info("Using custom text for tag generation")
            result = mcp_server.run_async(mcp_server.generate_tags_async(content=custom_text, max_tags=int(max_tags)))
        elif document_id:
            logger.info(f"Generating tags for document: {document_id}")
            result = mcp_server.run_async(mcp_server.generate_tags_async(document_id=document_id, max_tags=int(max_tags)))
        else:
            return "Please select a document from the dropdown or enter text to generate tags"

        if result["success"]:
            tags_str = ", ".join(result["tags"])
            output = f"🏷️ Generated Tags:\n\n{tags_str}\n\n"
            output += f"📊 Statistics:\n"
            output += f"- Content length: {result['content_length']} characters\n"
            output += f"- Number of tags: {len(result['tags'])}\n"
            if result.get('document_id'):
                output += f"- Document ID: {result['document_id']}\n"
            output += f"\n✅ Tags have been saved to the document."
            return output
        else:
            return f"❌ Tag generation failed: {result['error']}"
    except Exception as e:
        logger.error(f"Tag generation error: {str(e)}")
        return f"❌ Error: {str(e)}"

def ask_question(question):
    """Gradio interface for Q&A"""
    if not question.strip():
        return "Please enter a question"

    try:
        result = mcp_server.run_async(mcp_server.answer_question_async(question))

        if result["success"]:
            output = f"❓ Question: {result['question']}\n\n"
            output += f"💡 Answer:\n{result['answer']}\n\n"
            output += f"🎯 Confidence: {result['confidence']}\n\n"
            output += f"📚 Sources Used ({len(result['sources'])}):\n"
            for i, source in enumerate(result['sources'], 1):
                filename = source.get('metadata', {}).get('document_filename', 'Unknown')
                output += f"\n{i}. 📄 {filename}\n"
                output += f"   📝 Excerpt: {source['content'][:150]}...\n"
                output += f"   📊 Relevance: {source['score']:.3f}\n"
            return output
        else:
            return f"❌ {result.get('error', 'Failed to answer question')}"
    except Exception as e:
        return f"❌ Error: {str(e)}"

# Create Gradio Interface
def create_gradio_interface():
    with gr.Blocks(title="🧠 Intelligent Content Organizer MCP Agent", theme=gr.themes.Soft()) as interface:
        gr.Markdown("""
        # 🧠 Intelligent Content Organizer MCP Agent

        A powerful MCP (Model Context Protocol) server for intelligent content management with semantic search,
        summarization, and Q&A capabilities powered by Anthropic Claude and Mistral AI.

        ## 🚀 Quick Start:
        1. **Upload Documents** → Go to "📄 Upload Documents" tab
        2. **Search Your Content** → Use "🔍 Search Documents" to find information
        3. **Get Summaries** → Select any document in "📝 Summarize" tab
        4. **Ask Questions** → Get answers from your documents in "❓ Ask Questions" tab
        """)

        # Shared components for document selection
        doc_choices = gr.State(get_document_choices())

        with gr.Tabs():
            # Document Library Tab
            with gr.Tab("📚 Document Library"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Your Document Collection")
                        document_list = gr.Textbox(
                            label="Documents in Library",
                            value=get_document_list(),
                            lines=20,
                            interactive=False
                        )
                        refresh_btn = gr.Button("🔄 Refresh Library", variant="secondary")

                        refresh_btn.click(
                            fn=get_document_list,
                            outputs=[document_list]
                        )

            # Document Ingestion Tab
            with gr.Tab("📄 Upload Documents"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Add Documents to Your Library")
                        file_input = gr.File(
                            label="Select Document to Upload",
                            file_types=[".pdf", ".txt", ".docx", ".png", ".jpg", ".jpeg"],
                            type="filepath"
                        )
                        upload_btn = gr.Button("🚀 Process & Add to Library", variant="primary", size="lg")
                    with gr.Column():
                        upload_output = gr.Textbox(
                            label="Processing Result",
                            lines=6,
                            placeholder="Upload a document to see processing results..."
                        )
                        doc_id_output = gr.Textbox(
                            label="Document ID",
                            placeholder="Document ID will appear here after processing..."
                        )

                # Hidden dropdowns for updating
                doc_dropdown_sum = gr.Dropdown(label="Hidden", visible=False)
                doc_dropdown_tag = gr.Dropdown(label="Hidden", visible=False)

                upload_btn.click(
                    upload_and_process_file,
                    inputs=[file_input],
                    outputs=[upload_output, doc_id_output, document_list, doc_dropdown_sum, doc_dropdown_tag, doc_choices]
                )

            # Semantic Search Tab
            with gr.Tab("🔍 Search Documents"):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.Markdown("### Search Your Document Library")
                        search_query = gr.Textbox(
                            label="What are you looking for?",
                            placeholder="Enter your search query... (e.g., 'machine learning algorithms', 'quarterly revenue', 'project timeline')",
                            lines=2
                        )
                        search_top_k = gr.Slider(
                            label="Number of Results",
                            minimum=1,
                            maximum=20,
                            value=5,
                            step=1
                        )
                        search_btn = gr.Button("🔍 Search Library", variant="primary", size="lg")
                    with gr.Column(scale=2):
                        search_output = gr.Textbox(
                            label="Search Results",
                            lines=20,
                            placeholder="Search results will appear here..."
                        )

                search_btn.click(
                    perform_search,
                    inputs=[search_query, search_top_k],
                    outputs=[search_output]
                )

            # Summarization Tab
            with gr.Tab("📝 Summarize"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Generate Document Summaries")

                        with gr.Tab("From Library"):
                            doc_dropdown_sum = gr.Dropdown(
                                label="Select Document to Summarize",
                                choices=get_document_choices(),
                                value=None,
                                interactive=True,
                                allow_custom_value=False
                            )

                        with gr.Tab("Custom Text"):
                            summary_text = gr.Textbox(
                                label="Or Paste Text to Summarize",
                                placeholder="Paste any text here to summarize...",
                                lines=8
                            )

                        summary_style = gr.Dropdown(
                            label="Summary Style",
                            choices=["concise", "detailed", "bullet_points", "executive"],
                            value="concise",
                            info="Choose how you want the summary formatted"
                        )
                        summarize_btn = gr.Button("📝 Generate Summary", variant="primary", size="lg")

                    with gr.Column():
                        summary_output = gr.Textbox(
                            label="Generated Summary",
                            lines=20,
                            placeholder="Summary will appear here..."
                        )

                summarize_btn.click(
                    summarize_document,
                    inputs=[doc_dropdown_sum, summary_text, summary_style],
                    outputs=[summary_output]
                )

            # Tag Generation Tab
            with gr.Tab("🏷️ Generate Tags"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Auto-Generate Document Tags")

                        with gr.Tab("From Library"):
                            doc_dropdown_tag = gr.Dropdown(
                                label="Select Document to Tag",
                                choices=get_document_choices(),
                                value=None,
                                interactive=True,
                                allow_custom_value=False
                            )

                        with gr.Tab("Custom Text"):
                            tag_text = gr.Textbox(
                                label="Or Paste Text to Generate Tags",
                                placeholder="Paste any text here to generate tags...",
                                lines=8
                            )

                        max_tags = gr.Slider(
                            label="Number of Tags",
                            minimum=3,
                            maximum=15,
                            value=5,
                            step=1
                        )
                        tag_btn = gr.Button("🏷️ Generate Tags", variant="primary", size="lg")

                    with gr.Column():
                        tag_output = gr.Textbox(
                            label="Generated Tags",
                            lines=10,
                            placeholder="Tags will appear here..."
                        )

                tag_btn.click(
                    generate_tags_for_document,
                    inputs=[doc_dropdown_tag, tag_text, max_tags],
                    outputs=[tag_output]
                )

            # Q&A Tab
            with gr.Tab("❓ Ask Questions"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("""
                        ### Ask Questions About Your Documents

                        The AI will search through all your uploaded documents to find relevant information
                        and provide comprehensive answers with sources.
                        """)
                        qa_question = gr.Textbox(
                            label="Your Question",
                            placeholder="Ask anything about your documents... (e.g., 'What are the key findings about renewable energy?', 'How much was spent on marketing last quarter?')",
                            lines=3
                        )
                        qa_btn = gr.Button("❓ Get Answer", variant="primary", size="lg")

                    with gr.Column():
                        qa_output = gr.Textbox(
                            label="AI Answer",
                            lines=20,
                            placeholder="Answer will appear here with sources..."
                        )

                qa_btn.click(
                    ask_question,
                    inputs=[qa_question],
                    outputs=[qa_output]
                )

        # Auto-refresh document lists when switching tabs
        # (wrap the choices in gr.update so the dropdowns' options, not values, change)
        interface.load(
            fn=lambda: (
                get_document_list(),
                gr.update(choices=get_document_choices()),
                gr.update(choices=get_document_choices()),
            ),
            outputs=[document_list, doc_dropdown_sum, doc_dropdown_tag]
        )

    return interface

# Create and launch the interface
if __name__ == "__main__":
    interface = create_gradio_interface()

    # Launch with proper configuration for Hugging Face Spaces
    interface.launch(mcp_server=True)
```
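For reference, a minimal sketch of driving the server object directly, outside the Gradio UI. This assumes the services above can initialize in the current environment (e.g. any API keys the LLM and embedding services need are configured), and `report.pdf` is a hypothetical input file:

```python
# Ingest a document, then query it, using the sync bridge around the async tools.
ingest = mcp_server.run_async(
    mcp_server.ingest_document_async("report.pdf", ".pdf")
)
print(ingest.get("message"), ingest.get("document_id"))

hits = mcp_server.run_async(
    mcp_server.semantic_search_async("quarterly revenue", top_k=3)
)
for r in hits["results"]:
    print(f"{r['score']:.3f}", r["content"][:80])
```

Note that `interface.launch(mcp_server=True)` relies on a Gradio release with MCP support (the `gradio[mcp]` extra in recent versions), which exposes the app's documented functions as MCP tools alongside the web UI.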
core/__init__.py
ADDED
@@ -0,0 +1 @@

```python
# Core module initialization
```
core/chunker.py
ADDED
@@ -0,0 +1,302 @@

```python
import logging
from typing import List, Dict, Any, Optional
import re

from .models import Chunk
from .text_preprocessor import TextPreprocessor
import config

logger = logging.getLogger(__name__)

class TextChunker:
    def __init__(self):
        self.config = config.config
        self.preprocessor = TextPreprocessor()

        self.chunk_size = self.config.CHUNK_SIZE
        self.chunk_overlap = self.config.CHUNK_OVERLAP

    def chunk_document(self, document_id: str, content: str, method: str = "recursive") -> List[Chunk]:
        """Chunk a document using the specified method"""
        if not content:
            return []

        try:
            if method == "recursive":
                return self._recursive_chunk(document_id, content)
            elif method == "sentence":
                return self._sentence_chunk(document_id, content)
            elif method == "paragraph":
                return self._paragraph_chunk(document_id, content)
            elif method == "fixed":
                return self._fixed_chunk(document_id, content)
            else:
                logger.warning(f"Unknown chunking method: {method}, using recursive")
                return self._recursive_chunk(document_id, content)
        except Exception as e:
            logger.error(f"Error chunking document: {str(e)}")
            # Fallback to simple fixed chunking
            return self._fixed_chunk(document_id, content)

    def _recursive_chunk(self, document_id: str, content: str) -> List[Chunk]:
        """Recursively split text by different separators"""
        chunks = []

        # Define separators in order of preference
        separators = [
            "\n\n",  # Paragraphs
            "\n",    # Lines
            ". ",    # Sentences
            ", ",    # Clauses
            " "      # Words
        ]

        def split_text(text: str, separators: List[str], chunk_size: int) -> List[str]:
            if len(text) <= chunk_size:
                return [text] if text.strip() else []

            if not separators:
                # If no separators left, split by character
                return [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]

            separator = separators[0]
            remaining_separators = separators[1:]

            splits = text.split(separator)
            result = []
            current_chunk = ""

            for split in splits:
                if len(current_chunk) + len(split) + len(separator) <= chunk_size:
                    if current_chunk:
                        current_chunk += separator + split
                    else:
                        current_chunk = split
                else:
                    if current_chunk:
                        result.append(current_chunk)

                    if len(split) > chunk_size:
                        # Split is too big, need to split further
                        result.extend(split_text(split, remaining_separators, chunk_size))
                        current_chunk = ""
                    else:
                        current_chunk = split

            if current_chunk:
                result.append(current_chunk)

            return result

        text_chunks = split_text(content, separators, self.chunk_size)

        # Create chunk objects with overlap
        for i, chunk_text in enumerate(text_chunks):
            if not chunk_text.strip():
                continue

            # Calculate positions
            start_pos = content.find(chunk_text)
            if start_pos == -1:
                start_pos = i * self.chunk_size
            end_pos = start_pos + len(chunk_text)

            # Add overlap from previous chunk if not the first chunk
            if i > 0 and self.chunk_overlap > 0:
                prev_chunk = text_chunks[i-1]
                overlap_text = prev_chunk[-self.chunk_overlap:] if len(prev_chunk) > self.chunk_overlap else prev_chunk
                chunk_text = overlap_text + " " + chunk_text

            chunk = Chunk(
                id=self._generate_chunk_id(document_id, i),
                document_id=document_id,
                content=chunk_text.strip(),
                chunk_index=i,
                start_pos=start_pos,
                end_pos=end_pos,
                metadata={
                    "chunk_method": "recursive",
                    "original_length": len(chunk_text),
                    "word_count": len(chunk_text.split())
                }
            )
            chunks.append(chunk)

        return chunks

    def _sentence_chunk(self, document_id: str, content: str) -> List[Chunk]:
        """Chunk text by sentences"""
        chunks = []
        sentences = self.preprocessor.extract_sentences(content)

        current_chunk = ""
        chunk_index = 0
        start_pos = 0

        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= self.chunk_size:
                if current_chunk:
                    current_chunk += " " + sentence
                else:
                    current_chunk = sentence
                    start_pos = content.find(sentence)
            else:
                if current_chunk:
                    chunk = Chunk(
                        id=self._generate_chunk_id(document_id, chunk_index),
                        document_id=document_id,
                        content=current_chunk.strip(),
                        chunk_index=chunk_index,
                        start_pos=start_pos,
                        end_pos=start_pos + len(current_chunk),
                        metadata={
                            "chunk_method": "sentence",
                            "sentence_count": len(self.preprocessor.extract_sentences(current_chunk))
                        }
                    )
                    chunks.append(chunk)
                    chunk_index += 1

                current_chunk = sentence
                start_pos = content.find(sentence)

        # Add final chunk
        if current_chunk:
            chunk = Chunk(
                id=self._generate_chunk_id(document_id, chunk_index),
                document_id=document_id,
                content=current_chunk.strip(),
                chunk_index=chunk_index,
                start_pos=start_pos,
                end_pos=start_pos + len(current_chunk),
                metadata={
                    "chunk_method": "sentence",
                    "sentence_count": len(self.preprocessor.extract_sentences(current_chunk))
                }
            )
            chunks.append(chunk)

        return chunks

    def _paragraph_chunk(self, document_id: str, content: str) -> List[Chunk]:
        """Chunk text by paragraphs"""
        chunks = []
        paragraphs = [p.strip() for p in content.split('\n\n') if p.strip()]

        current_chunk = ""
        chunk_index = 0
        start_pos = 0

        for paragraph in paragraphs:
            if len(current_chunk) + len(paragraph) <= self.chunk_size:
                if current_chunk:
                    current_chunk += "\n\n" + paragraph
                else:
                    current_chunk = paragraph
                    start_pos = content.find(paragraph)
            else:
                if current_chunk:
                    chunk = Chunk(
                        id=self._generate_chunk_id(document_id, chunk_index),
                        document_id=document_id,
                        content=current_chunk.strip(),
                        chunk_index=chunk_index,
                        start_pos=start_pos,
                        end_pos=start_pos + len(current_chunk),
                        metadata={
                            "chunk_method": "paragraph",
                            "paragraph_count": len([p for p in current_chunk.split('\n\n') if p.strip()])
                        }
                    )
                    chunks.append(chunk)
                    chunk_index += 1

                # If paragraph is too long, split it further
                if len(paragraph) > self.chunk_size:
                    para_chunks = self._fixed_chunk(document_id, paragraph)
                    for pc in para_chunks:
                        pc.chunk_index = chunk_index
                        pc.id = self._generate_chunk_id(document_id, chunk_index)
                        chunks.append(pc)
                        chunk_index += 1
                else:
                    current_chunk = paragraph
                    start_pos = content.find(paragraph)

        # Add final chunk
        if current_chunk:
            chunk = Chunk(
                id=self._generate_chunk_id(document_id, chunk_index),
                document_id=document_id,
                content=current_chunk.strip(),
                chunk_index=chunk_index,
                start_pos=start_pos,
                end_pos=start_pos + len(current_chunk),
                metadata={
                    "chunk_method": "paragraph",
                    "paragraph_count": len([p for p in current_chunk.split('\n\n') if p.strip()])
                }
            )
            chunks.append(chunk)

        return chunks

    def _fixed_chunk(self, document_id: str, content: str) -> List[Chunk]:
        """Simple fixed-size chunking with overlap"""
        chunks = []

        for i in range(0, len(content), self.chunk_size - self.chunk_overlap):
            chunk_text = content[i:i + self.chunk_size]

            if not chunk_text.strip():
                continue

            chunk = Chunk(
                id=self._generate_chunk_id(document_id, len(chunks)),
                document_id=document_id,
                content=chunk_text.strip(),
                chunk_index=len(chunks),
                start_pos=i,
                end_pos=min(i + self.chunk_size, len(content)),
                metadata={
                    "chunk_method": "fixed",
                    "original_length": len(chunk_text)
                }
            )
            chunks.append(chunk)

        return chunks

    def _generate_chunk_id(self, document_id: str, chunk_index: int) -> str:
        """Generate a unique chunk ID"""
        return f"{document_id}_chunk_{chunk_index}"

    def optimize_chunks_for_embedding(self, chunks: List[Chunk]) -> List[Chunk]:
        """Optimize chunks for better embedding generation"""
        optimized_chunks = []

        for chunk in chunks:
            # Clean the content for embedding
            clean_content = self.preprocessor.prepare_for_embedding(chunk.content)

            # Skip very short chunks
            if len(clean_content.split()) < 5:
                continue

            # Update chunk with optimized content
            optimized_chunk = Chunk(
                id=chunk.id,
                document_id=chunk.document_id,
                content=clean_content,
                chunk_index=chunk.chunk_index,
                start_pos=chunk.start_pos,
                end_pos=chunk.end_pos,
                metadata={
                    **chunk.metadata,
                    "optimized_for_embedding": True,
                    "original_content_length": len(chunk.content),
                    "optimized_content_length": len(clean_content)
                }
            )
            optimized_chunks.append(optimized_chunk)

        return optimized_chunks
```
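A short usage sketch of the chunker as defined above. It assumes `config.config` supplies `CHUNK_SIZE` and `CHUNK_OVERLAP`, and `notes.txt` is a hypothetical input file:

```python
from core.chunker import TextChunker

chunker = TextChunker()
text = open("notes.txt", encoding="utf-8").read()

# Split on paragraph boundaries, then clean the chunks for embedding.
chunks = chunker.chunk_document("doc-123", text, method="paragraph")
embeddable = chunker.optimize_chunks_for_embedding(chunks)
for c in embeddable:
    print(c.id, c.chunk_index, len(c.content))
```

As written, `_recursive_chunk` prefers paragraph and line boundaries before sentence, clause, and word splits, and `optimize_chunks_for_embedding` silently drops chunks shorter than five words.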
core/document_parser.py
ADDED
@@ -0,0 +1,199 @@

```python
import logging
import tempfile
import os
from pathlib import Path
from typing import Optional, Dict, Any
import asyncio

# Document processing libraries
import PyPDF2
from docx import Document as DocxDocument
from PIL import Image
import pytesseract

from .models import Document, DocumentType
import config

logger = logging.getLogger(__name__)

class DocumentParser:
    def __init__(self):
        self.config = config.config

    async def parse_document(self, file_path: str, filename: str) -> Document:
        """Parse a document and extract its content"""
        try:
            file_ext = Path(filename).suffix.lower()
            file_size = os.path.getsize(file_path)

            # Determine document type and parse accordingly
            if file_ext == '.pdf':
                content = await self._parse_pdf(file_path)
                doc_type = DocumentType.PDF
            elif file_ext == '.txt':
                content = await self._parse_text(file_path)
                doc_type = DocumentType.TEXT
            elif file_ext == '.docx':
                content = await self._parse_docx(file_path)
                doc_type = DocumentType.DOCX
            elif file_ext in ['.png', '.jpg', '.jpeg', '.bmp', '.tiff']:
                content = await self._parse_image(file_path)
                doc_type = DocumentType.IMAGE
            else:
                raise ValueError(f"Unsupported file type: {file_ext}")

            # Create document object
            document = Document(
                id=self._generate_document_id(),
                filename=filename,
                content=content,
                doc_type=doc_type,
                file_size=file_size,
                metadata={
                    "file_extension": file_ext,
                    "content_length": len(content),
                    "word_count": len(content.split()) if content else 0
                }
            )

            logger.info(f"Successfully parsed document: {filename}")
            return document

        except Exception as e:
            logger.error(f"Error parsing document {filename}: {str(e)}")
            raise

    async def _parse_pdf(self, file_path: str) -> str:
        """Extract text from PDF file"""
        try:
            content = ""
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page_num, page in enumerate(pdf_reader.pages):
                    try:
                        page_text = page.extract_text()
                        if page_text.strip():
                            content += f"\n--- Page {page_num + 1} ---\n"
                            content += page_text + "\n"
                    except Exception as e:
                        logger.warning(f"Error extracting text from page {page_num + 1}: {str(e)}")
                        continue

            return content.strip()
        except Exception as e:
            logger.error(f"Error parsing PDF: {str(e)}")
            raise

    async def _parse_text(self, file_path: str) -> str:
        """Read plain text file"""
        try:
            with open(file_path, 'r', encoding='utf-8', errors='ignore') as file:
                content = file.read()
            return content.strip()
        except Exception as e:
            logger.error(f"Error parsing text file: {str(e)}")
            raise

    async def _parse_docx(self, file_path: str) -> str:
        """Extract text from DOCX file"""
        try:
            doc = DocxDocument(file_path)
            content = ""

            for paragraph in doc.paragraphs:
                if paragraph.text.strip():
                    content += paragraph.text + "\n"

            # Extract text from tables
            for table in doc.tables:
                for row in table.rows:
                    row_text = []
                    for cell in row.cells:
                        if cell.text.strip():
                            row_text.append(cell.text.strip())
                    if row_text:
                        content += " | ".join(row_text) + "\n"

            return content.strip()
        except Exception as e:
            logger.error(f"Error parsing DOCX file: {str(e)}")
            raise

    async def _parse_image(self, file_path: str) -> str:
        """Extract text from image using OCR"""
        try:
            # First try with OCR service if available
            if hasattr(self, 'ocr_service') and self.ocr_service:
                logger.info(f"Using OCR service for image: {file_path}")
                text = await self.ocr_service.extract_text_from_image(file_path)
                if text:
                    return text

            # Fallback to direct pytesseract
            logger.info(f"Using direct pytesseract for image: {file_path}")
            image = Image.open(file_path)

            # Perform OCR
            content = pytesseract.image_to_string(
                image,
                lang=self.config.OCR_LANGUAGE,
                config='--psm 6'  # Assume a single uniform block of text
            )

            return content.strip()
        except Exception as e:
            logger.error(f"Error performing OCR on image: {str(e)}")
            # Return empty string if OCR fails
            return ""

    def _generate_document_id(self) -> str:
        """Generate a unique document ID"""
        import uuid
        return str(uuid.uuid4())

    async def extract_metadata(self, file_path: str, content: str) -> Dict[str, Any]:
        """Extract additional metadata from the document"""
        try:
            metadata = {}

            # Basic statistics
            metadata["content_length"] = len(content)
            metadata["word_count"] = len(content.split()) if content else 0
            metadata["line_count"] = len(content.splitlines()) if content else 0

            # File information
            file_stat = os.stat(file_path)
            metadata["file_size"] = file_stat.st_size
            metadata["created_time"] = file_stat.st_ctime
            metadata["modified_time"] = file_stat.st_mtime

            # Content analysis
            if content:
                # Language detection (simple heuristic)
                metadata["estimated_language"] = self._detect_language(content)

                # Reading time estimation (average 200 words per minute)
                metadata["estimated_reading_time_minutes"] = max(1, metadata["word_count"] // 200)

            return metadata
        except Exception as e:
            logger.error(f"Error extracting metadata: {str(e)}")
            return {}

    def _detect_language(self, content: str) -> str:
        """Simple language detection based on character patterns"""
        # This is a very basic implementation
        # In production, you might want to use a proper language detection library
        if not content:
            return "unknown"

        # Count common English words
        english_words = ["the", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with", "by", "from", "as", "is", "was", "are", "were", "be", "been", "have", "has", "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "can", "this", "that", "these", "those"]

        words = content.lower().split()
        english_count = sum(1 for word in words if word in english_words)

        if len(words) > 0 and english_count / len(words) > 0.1:
            return "en"
        else:
            return "unknown"
```
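Since `parse_document` is a coroutine, a caller outside an event loop would run it with `asyncio.run`. A sketch, assuming the Tesseract binary is installed for image OCR, `config.config.OCR_LANGUAGE` is set, and `contract.pdf` is a hypothetical file:

```python
import asyncio
from core.document_parser import DocumentParser

async def main():
    parser = DocumentParser()
    # parse_document takes (file_path, filename) separately
    doc = await parser.parse_document("contract.pdf", "contract.pdf")
    meta = await parser.extract_metadata("contract.pdf", doc.content)
    print(doc.id, doc.doc_type, meta.get("estimated_reading_time_minutes"))

asyncio.run(main())
```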
core/models.py
ADDED
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
from datetime import datetime
from enum import Enum

class DocumentType(str, Enum):
    PDF = "pdf"
    TEXT = "txt"
    DOCX = "docx"
    IMAGE = "image"
    HTML = "html"

class ProcessingStatus(str, Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"

class Document(BaseModel):
    id: str = Field(..., description="Unique document identifier")
    filename: str = Field(..., description="Original filename")
    content: str = Field(..., description="Extracted text content")
    doc_type: DocumentType = Field(..., description="Document type")
    file_size: int = Field(..., description="File size in bytes")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    metadata: Dict[str, Any] = Field(default_factory=dict)
    tags: List[str] = Field(default_factory=list)
    summary: Optional[str] = None
    category: Optional[str] = None
    language: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return {
            "id": self.id,
            "filename": self.filename,
            "content": self.content[:500] + "..." if len(self.content) > 500 else self.content,
            "doc_type": self.doc_type,
            "file_size": self.file_size,
            "created_at": self.created_at.isoformat(),
            "metadata": self.metadata,
            "tags": self.tags,
            "summary": self.summary,
            "category": self.category,
            "language": self.language
        }

class Chunk(BaseModel):
    id: str = Field(..., description="Unique chunk identifier")
    document_id: str = Field(..., description="Parent document ID")
    content: str = Field(..., description="Chunk text content")
    chunk_index: int = Field(..., description="Position in document")
    start_pos: int = Field(..., description="Start position in original document")
    end_pos: int = Field(..., description="End position in original document")
    embedding: Optional[List[float]] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)

class SearchResult(BaseModel):
    chunk_id: str = Field(..., description="Matching chunk ID")
    document_id: str = Field(..., description="Source document ID")
    content: str = Field(..., description="Matching content")
    score: float = Field(..., description="Similarity score")
    metadata: Dict[str, Any] = Field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {
            "chunk_id": self.chunk_id,
            "document_id": self.document_id,
            "content": self.content,
            "score": self.score,
            "metadata": self.metadata
        }

class ProcessingTask(BaseModel):
    task_id: str = Field(..., description="Unique task identifier")
    document_id: Optional[str] = None
    status: ProcessingStatus = ProcessingStatus.PENDING
    progress: float = Field(default=0.0, ge=0.0, le=100.0)
    message: Optional[str] = None
    error: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)

class SummaryRequest(BaseModel):
    content: Optional[str] = None
    document_id: Optional[str] = None
    style: str = Field(default="concise", description="Summary style")
    max_length: Optional[int] = None

class TagGenerationRequest(BaseModel):
    content: Optional[str] = None
    document_id: Optional[str] = None
    max_tags: int = Field(default=5, ge=1, le=20)

class QuestionAnswerRequest(BaseModel):
    question: str = Field(..., description="Question to answer")
    context_filter: Optional[Dict[str, Any]] = None
    max_context_length: int = Field(default=2000)

class CategorizationRequest(BaseModel):
    content: Optional[str] = None
    document_id: Optional[str] = None
    categories: Optional[List[str]] = None
core/text_preprocessor.py
ADDED
@@ -0,0 +1,186 @@
import re
import logging
from typing import List, Optional
import unicodedata

logger = logging.getLogger(__name__)

class TextPreprocessor:
    def __init__(self):
        # Common stop words for basic filtering
        self.stop_words = {
            'en': set([
                'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
                'of', 'with', 'by', 'from', 'up', 'about', 'into', 'through', 'during',
                'before', 'after', 'above', 'below', 'between', 'among', 'throughout',
                'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had',
                'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might',
                'must', 'shall', 'can', 'this', 'that', 'these', 'those', 'i', 'me',
                'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you', 'your', 'yours'
            ])
        }

    def clean_text(self, text: str, aggressive: bool = False) -> str:
        """Clean and normalize text"""
        if not text:
            return ""

        try:
            # Normalize unicode characters
            text = unicodedata.normalize('NFKD', text)

            # Remove excessive whitespace
            text = re.sub(r'\s+', ' ', text)

            # Remove or replace special characters
            if aggressive:
                # More aggressive cleaning for embedding
                text = re.sub(r'[^\w\s\-.,!?;:]', ' ', text)
                text = re.sub(r'[.,!?;:]+', '.', text)
            else:
                # Basic cleaning for readability
                text = re.sub(r'[^\w\s\-.,!?;:()\[\]{}"\']', ' ', text)

            # Remove excessive punctuation
            text = re.sub(r'\.{2,}', '.', text)
            text = re.sub(r'[!?]{2,}', '!', text)

            # Clean up whitespace again
            text = re.sub(r'\s+', ' ', text)

            # Remove leading/trailing whitespace
            text = text.strip()

            return text
        except Exception as e:
            logger.error(f"Error cleaning text: {str(e)}")
            return text

    def extract_sentences(self, text: str) -> List[str]:
        """Extract sentences from text"""
        if not text:
            return []

        try:
            # Simple sentence splitting
            sentences = re.split(r'[.!?]+', text)

            # Clean and filter sentences
            clean_sentences = []
            for sentence in sentences:
                sentence = sentence.strip()
                if len(sentence) > 10:  # Minimum sentence length
                    clean_sentences.append(sentence)

            return clean_sentences
        except Exception as e:
            logger.error(f"Error extracting sentences: {str(e)}")
            return [text]

    def extract_keywords(self, text: str, language: str = 'en', max_keywords: int = 20) -> List[str]:
        """Extract potential keywords from text"""
        if not text:
            return []

        try:
            # Convert to lowercase and split into words
            words = re.findall(r'\b[a-zA-Z]{3,}\b', text.lower())

            # Remove stop words
            stop_words = self.stop_words.get(language, set())
            keywords = [word for word in words if word not in stop_words]

            # Count word frequency
            word_freq = {}
            for word in keywords:
                word_freq[word] = word_freq.get(word, 0) + 1

            # Sort by frequency and return top keywords
            sorted_keywords = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)

            return [word for word, freq in sorted_keywords[:max_keywords]]
        except Exception as e:
            logger.error(f"Error extracting keywords: {str(e)}")
            return []

    def prepare_for_embedding(self, text: str) -> str:
        """Prepare text specifically for embedding generation"""
        if not text:
            return ""

        try:
            # Clean text aggressively for better embeddings
            clean_text = self.clean_text(text, aggressive=True)

            # Remove very short words
            words = clean_text.split()
            filtered_words = [word for word in words if len(word) >= 2]

            # Rejoin and ensure reasonable length
            result = ' '.join(filtered_words)

            # Truncate if too long (most embedding models have token limits)
            if len(result) > 5000:  # Rough character limit
                result = result[:5000] + "..."

            return result
        except Exception as e:
            logger.error(f"Error preparing text for embedding: {str(e)}")
            return text

    def extract_metadata_from_text(self, text: str) -> dict:
        """Extract metadata from text content"""
        if not text:
            return {}

        try:
            metadata = {}

            # Basic statistics
            metadata['character_count'] = len(text)
            metadata['word_count'] = len(text.split())
            metadata['sentence_count'] = len(self.extract_sentences(text))
            metadata['paragraph_count'] = len([p for p in text.split('\n\n') if p.strip()])

            # Content characteristics
            metadata['avg_word_length'] = sum(len(word) for word in text.split()) / max(1, len(text.split()))
            metadata['avg_sentence_length'] = metadata['word_count'] / max(1, metadata['sentence_count'])

            # Special content detection
            metadata['has_urls'] = bool(re.search(r'https?://\S+', text))
            metadata['has_emails'] = bool(re.search(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b', text))
            metadata['has_phone_numbers'] = bool(re.search(r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', text))
            metadata['has_dates'] = bool(re.search(r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b', text))
            metadata['has_numbers'] = bool(re.search(r'\b\d+\b', text))

            # Language indicators
            metadata['punctuation_density'] = len(re.findall(r'[.,!?;:]', text)) / max(1, len(text))
            metadata['caps_ratio'] = len(re.findall(r'[A-Z]', text)) / max(1, len(text))

            return metadata
        except Exception as e:
            logger.error(f"Error extracting text metadata: {str(e)}")
            return {}

    def normalize_for_search(self, text: str) -> str:
        """Normalize text for search queries"""
        if not text:
            return ""

        try:
            # Convert to lowercase
            text = text.lower()

            # Remove special characters but keep spaces
            text = re.sub(r'[^\w\s]', ' ', text)

            # Normalize whitespace
            text = re.sub(r'\s+', ' ', text)

            # Strip leading/trailing whitespace
            text = text.strip()

            return text
        except Exception as e:
            logger.error(f"Error normalizing text for search: {str(e)}")
            return text
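A minimal usage sketch for TextPreprocessor (not part of the commit; the sample string is invented):

pre = TextPreprocessor()
raw = "Visit   https://example.com!!!  Email me at a@b.co ... thanks!!"
print(pre.clean_text(raw))                              # whitespace collapsed, punctuation runs reduced
print(pre.extract_keywords(raw, max_keywords=3))        # top non-stop-word tokens by frequency
print(pre.extract_metadata_from_text(raw)["has_urls"])  # True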
mcp_server.py
CHANGED
@@ -1,108 +1,203 @@
(Removed in this revision: the previous process_file, process_url, semantic_search, get_document_summary, and get_server_info tool definitions; the rewritten module follows.)

import asyncio
import logging
from typing import Dict, Any, List, Optional
from pathlib import Path

from mcp.server.fastmcp import FastMCP

from services.vector_store_service import VectorStoreService
from services.document_store_service import DocumentStoreService
from services.embedding_service import EmbeddingService
from services.llm_service import LLMService
from services.ocr_service import OCRService

from mcp_tools.ingestion_tool import IngestionTool
from mcp_tools.search_tool import SearchTool
from mcp_tools.generative_tool import GenerativeTool

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("Initializing services for FastMCP...")
vector_store_service = VectorStoreService()
document_store_service = DocumentStoreService()
embedding_service_instance = EmbeddingService()
llm_service_instance = LLMService()
ocr_service_instance = OCRService()

ingestion_tool_instance = IngestionTool(
    vector_store=vector_store_service,
    document_store=document_store_service,
    embedding_service=embedding_service_instance,
    ocr_service=ocr_service_instance
)
search_tool_instance = SearchTool(
    vector_store=vector_store_service,
    embedding_service=embedding_service_instance,
    document_store=document_store_service
)
generative_tool_instance = GenerativeTool(
    llm_service=llm_service_instance,
    search_tool=search_tool_instance
)

mcp = FastMCP("intelligent-content-organizer-fmcp")
logger.info("FastMCP server initialized.")

@mcp.tool()
async def ingest_document(file_path: str, file_type: Optional[str] = None) -> Dict[str, Any]:
    """
    Process and index a document from a local file path for searching.
    Automatically determines file_type if not provided.
    """
    logger.info(f"Tool 'ingest_document' called with file_path: {file_path}, file_type: {file_type}")
    try:
        actual_file_type = file_type
        if not actual_file_type:
            actual_file_type = Path(file_path).suffix.lower().strip('.')
            logger.info(f"Inferred file_type: {actual_file_type}")
        result = await ingestion_tool_instance.process_document(file_path, actual_file_type)
        logger.info(f"Ingestion result: {result}")
        return result
    except Exception as e:
        logger.error(f"Error in 'ingest_document' tool: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e)}

@mcp.tool()
async def semantic_search(query: str, top_k: int = 5, filters: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Search through indexed content using natural language.
    'filters' can be used to narrow down the search.
    """
    logger.info(f"Tool 'semantic_search' called with query: {query}, top_k: {top_k}, filters: {filters}")
    try:
        results = await search_tool_instance.search(query, top_k, filters)
        return {
            "success": True,
            "query": query,
            "results": [result.to_dict() for result in results],
            "total_results": len(results)
        }
    except Exception as e:
        logger.error(f"Error in 'semantic_search' tool: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e), "results": []}

@mcp.tool()
async def summarize_content(
    content: Optional[str] = None,
    document_id: Optional[str] = None,
    style: str = "concise"
) -> Dict[str, Any]:
    """
    Generate a summary of provided content or a document_id.
    Available styles: concise, detailed, bullet_points, executive.
    """
    logger.info(f"Tool 'summarize_content' called. doc_id: {document_id}, style: {style}, has_content: {content is not None}")
    try:
        text_to_summarize = content
        if document_id and not text_to_summarize:
            doc = await document_store_service.get_document(document_id)
            if not doc:
                return {"success": False, "error": f"Document {document_id} not found"}
            text_to_summarize = doc.content
        if not text_to_summarize:
            return {"success": False, "error": "No content provided for summarization"}
        max_length = 10000
        if len(text_to_summarize) > max_length:
            logger.warning(f"Content for summarization is long ({len(text_to_summarize)} chars), truncating to {max_length}")
            text_to_summarize = text_to_summarize[:max_length] + "..."
        summary = await generative_tool_instance.summarize(text_to_summarize, style)
        return {
            "success": True,
            "summary": summary,
            "original_length": len(text_to_summarize),
            "summary_length": len(summary),
            "style": style
        }
    except Exception as e:
        logger.error(f"Error in 'summarize_content' tool: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e)}

@mcp.tool()
async def generate_tags(
    content: Optional[str] = None,
    document_id: Optional[str] = None,
    max_tags: int = 5
) -> Dict[str, Any]:
    """
    Generate relevant tags for content or a document_id.
    Saves tags to document metadata if document_id is provided.
    """
    logger.info(f"Tool 'generate_tags' called. doc_id: {document_id}, max_tags: {max_tags}, has_content: {content is not None}")
    try:
        text_for_tags = content
        if document_id and not text_for_tags:
            doc = await document_store_service.get_document(document_id)
            if not doc:
                return {"success": False, "error": f"Document {document_id} not found"}
            text_for_tags = doc.content
        if not text_for_tags:
            return {"success": False, "error": "No content provided for tag generation"}
        tags = await generative_tool_instance.generate_tags(text_for_tags, max_tags)
        if document_id and tags:
            await document_store_service.update_document_metadata(document_id, {"tags": tags})
            logger.info(f"Tags {tags} saved for document {document_id}")
        return {
            "success": True,
            "tags": tags,
            "content_length": len(text_for_tags),
            "document_id": document_id
        }
    except Exception as e:
        logger.error(f"Error in 'generate_tags' tool: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e)}

@mcp.tool()
async def answer_question(question: str, context_filter: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Answer questions using RAG (Retrieval Augmented Generation) over indexed content.
    'context_filter' can be used to narrow down the context search.
    """
    logger.info(f"Tool 'answer_question' called with question: {question}, context_filter: {context_filter}")
    try:
        search_results = await search_tool_instance.search(question, top_k=5, filters=context_filter)
        if not search_results:
            return {
                "success": False,
                "error": "No relevant context found. Please upload relevant documents.",
                "question": question,
                "answer": "I could not find enough information in the documents to answer your question."
            }
        answer = await generative_tool_instance.answer_question(question, search_results)
        return {
            "success": True,
            "question": question,
            "answer": answer,
            "sources": [result.to_dict() for result in search_results],
            "confidence": "high" if len(search_results) >= 3 else "medium"
        }
    except Exception as e:
        logger.error(f"Error in 'answer_question' tool: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e)}

@mcp.tool()
async def list_documents_for_ui(limit: int = 100, offset: int = 0) -> Dict[str, Any]:
    """
    (UI Helper) List documents from the document store.
    Not a standard processing tool, but useful for UI population.
    """
    logger.info(f"Tool 'list_documents_for_ui' called with limit: {limit}, offset: {offset}")
    try:
        documents = await document_store_service.list_documents(limit, offset)
        return {
            "success": True,
            "documents": [doc.to_dict() for doc in documents],
            "total": len(documents)
        }
    except Exception as e:
        logger.error(f"Error in 'list_documents_for_ui' tool: {str(e)}", exc_info=True)
        return {"success": False, "error": str(e), "documents": []}

if __name__ == "__main__":
    logger.info("Starting FastMCP server...")
    asyncio.run(mcp.run())
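A local smoke test for the rewritten tools (a sketch, not part of the commit): it assumes FastMCP's @mcp.tool() decorator leaves each coroutine directly callable, that the services above can initialize locally, and that samples/report.pdf is a file you supply yourself.

import asyncio

async def smoke_test():
    ingested = await ingest_document("samples/report.pdf")  # hypothetical sample path
    print(ingested)
    if ingested.get("success"):
        hits = await semantic_search("key findings", top_k=3)
        for r in hits.get("results", []):
            print(r["score"], r["content"][:80])

asyncio.run(smoke_test())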
mcp_tools.py
DELETED
@@ -1,592 +0,0 @@
import asyncio
import aiohttp
import chromadb
from chromadb.utils import embedding_functions
import json
import logging
from typing import Dict, List, Any, Optional
from datetime import datetime
import hashlib
from pathlib import Path
import requests

# Document processing libraries (all free)
import PyPDF2
import docx
from bs4 import BeautifulSoup
import pandas as pd
import markdown
import xml.etree.ElementTree as ET
from newspaper import Article
import trafilatura
from duckduckgo_search import DDGS

# AI libraries
from config import Config
from mistralai.client import MistralClient
import anthropic

# Set up logging
logger = logging.getLogger(__name__)

# Initialize AI clients
mistral_client = MistralClient(api_key=Config.MISTRAL_API_KEY) if Config.MISTRAL_API_KEY else None
anthropic_client = anthropic.Anthropic(api_key=Config.ANTHROPIC_API_KEY) if Config.ANTHROPIC_API_KEY else None

# Initialize ChromaDB
chroma_client = chromadb.PersistentClient(path=Config.CHROMA_DB_PATH)
embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name=Config.EMBEDDING_MODEL
)

# Get or create collection
try:
    collection = chroma_client.get_collection(
        name=Config.CHROMA_COLLECTION_NAME,
        embedding_function=embedding_function
    )
except:
    collection = chroma_client.create_collection(
        name=Config.CHROMA_COLLECTION_NAME,
        embedding_function=embedding_function
    )

class DocumentProcessor:
    """Free document processing without Unstructured API"""

    @staticmethod
    def extract_text_from_pdf(file_path: str) -> str:
        """Extract text from PDF files"""
        text = ""
        try:
            with open(file_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                for page_num in range(len(pdf_reader.pages)):
                    page = pdf_reader.pages[page_num]
                    text += page.extract_text() + "\n"
        except Exception as e:
            logger.error(f"Error reading PDF: {e}")
        return text

    @staticmethod
    def extract_text_from_docx(file_path: str) -> str:
        """Extract text from DOCX files"""
        try:
            doc = docx.Document(file_path)
            text = "\n".join([paragraph.text for paragraph in doc.paragraphs])
            return text
        except Exception as e:
            logger.error(f"Error reading DOCX: {e}")
            return ""

    @staticmethod
    def extract_text_from_html(file_path: str) -> str:
        """Extract text from HTML files"""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                soup = BeautifulSoup(file.read(), 'html.parser')
                # Remove script and style elements
                for script in soup(["script", "style"]):
                    script.extract()
                text = soup.get_text()
                lines = (line.strip() for line in text.splitlines())
                chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
                text = '\n'.join(chunk for chunk in chunks if chunk)
                return text
        except Exception as e:
            logger.error(f"Error reading HTML: {e}")
            return ""

    @staticmethod
    def extract_text_from_txt(file_path: str) -> str:
        """Extract text from TXT files"""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read()
        except Exception as e:
            logger.error(f"Error reading TXT: {e}")
            return ""

    @staticmethod
    def extract_text_from_csv(file_path: str) -> str:
        """Extract text from CSV files"""
        try:
            df = pd.read_csv(file_path)
            return df.to_string()
        except Exception as e:
            logger.error(f"Error reading CSV: {e}")
            return ""

    @staticmethod
    def extract_text_from_json(file_path: str) -> str:
        """Extract text from JSON files"""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
                return json.dumps(data, indent=2)
        except Exception as e:
            logger.error(f"Error reading JSON: {e}")
            return ""

    @staticmethod
    def extract_text_from_markdown(file_path: str) -> str:
        """Extract text from Markdown files"""
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                md_text = file.read()
                html = markdown.markdown(md_text)
                soup = BeautifulSoup(html, 'html.parser')
                return soup.get_text()
        except Exception as e:
            logger.error(f"Error reading Markdown: {e}")
            return ""

    @staticmethod
    def extract_text_from_xml(file_path: str) -> str:
        """Extract text from XML files"""
        try:
            tree = ET.parse(file_path)
            root = tree.getroot()

            def extract_text(element):
                text = element.text or ""
                for child in element:
                    text += " " + extract_text(child)
                return text.strip()

            return extract_text(root)
        except Exception as e:
            logger.error(f"Error reading XML: {e}")
            return ""

    @classmethod
    def extract_text(cls, file_path: str) -> str:
        """Extract text from any supported file type"""
        path = Path(file_path)
        extension = path.suffix.lower()

        extractors = {
            '.pdf': cls.extract_text_from_pdf,
            '.docx': cls.extract_text_from_docx,
            '.doc': cls.extract_text_from_docx,
            '.html': cls.extract_text_from_html,
            '.htm': cls.extract_text_from_html,
            '.txt': cls.extract_text_from_txt,
            '.csv': cls.extract_text_from_csv,
            '.json': cls.extract_text_from_json,
            '.md': cls.extract_text_from_markdown,
            '.xml': cls.extract_text_from_xml,
        }

        extractor = extractors.get(extension, cls.extract_text_from_txt)
        return extractor(file_path)

def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
    """Split text into chunks with overlap"""
    chunks = []
    start = 0
    text_length = len(text)

    while start < text_length:
        end = start + chunk_size
        chunk = text[start:end]

        # Try to find a sentence boundary
        if end < text_length:
            last_period = chunk.rfind('.')
            last_newline = chunk.rfind('\n')
            boundary = max(last_period, last_newline)

            if boundary > chunk_size // 2:
                chunk = text[start:start + boundary + 1]
                end = start + boundary + 1

        chunks.append(chunk.strip())
        start = end - overlap

    return chunks

async def fetch_web_content_free(url: str) -> Optional[str]:
    """Fetch content from URL using multiple free methods"""

    # Method 1: Try newspaper3k (best for articles)
    try:
        article = Article(url)
        article.download()
        article.parse()

        content = f"{article.title}\n\n{article.text}"
        if len(content) > 100:  # Valid content
            return content
    except Exception as e:
        logger.debug(f"Newspaper failed: {e}")

    # Method 2: Try trafilatura (great for web scraping)
    try:
        downloaded = trafilatura.fetch_url(url)
        content = trafilatura.extract(downloaded)
        if content and len(content) > 100:
            return content
    except Exception as e:
        logger.debug(f"Trafilatura failed: {e}")

    # Method 3: Basic BeautifulSoup scraping
    try:
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        }
        response = requests.get(url, headers=headers, timeout=10)

        if response.status_code == 200:
            soup = BeautifulSoup(response.text, 'html.parser')

            # Remove unwanted elements
            for element in soup(['script', 'style', 'nav', 'footer', 'header']):
                element.decompose()

            # Try to find main content
            main_content = None

            # Common content selectors
            content_selectors = [
                'main', 'article', '[role="main"]',
                '.content', '#content', '.post', '.entry-content',
                '.article-body', '.story-body'
            ]

            for selector in content_selectors:
                main_content = soup.select_one(selector)
                if main_content:
                    break

            if not main_content:
                main_content = soup.find('body')

            if main_content:
                text = main_content.get_text(separator='\n', strip=True)

                # Get title
                title = soup.find('title')
                title_text = title.get_text() if title else "No title"

                return f"{title_text}\n\n{text}"

    except Exception as e:
        logger.error(f"BeautifulSoup failed: {e}")

    return None

async def search_web_free(query: str, num_results: int = 5) -> List[Dict[str, str]]:
    """Search the web using free methods (DuckDuckGo)"""
    try:
        results = []
        with DDGS() as ddgs:
            for r in ddgs.text(query, max_results=num_results):
                results.append({
                    'title': r.get('title', ''),
                    'url': r.get('link', ''),
                    'snippet': r.get('body', '')
                })

        return results

    except Exception as e:
        logger.error(f"Search failed: {e}")
        return []

async def generate_tags(content: str) -> List[str]:
    """Generate tags using Mistral AI or fallback to free method"""
    try:
        if mistral_client:  # This is MistralClient from mistralai.client
            prompt = f"""Analyze this content and generate 5-7 relevant tags.
Return only the tags as a comma-separated list.

Content: {content[:2000]}...

Tags:"""

            # For mistralai==0.4.2, pass messages as a list of dicts
            response = mistral_client.chat(
                model=Config.MISTRAL_MODEL,
                messages=[{"role": "user", "content": prompt}]
            )

            tags_text = response.choices[0].message.content.strip()
            tags = [tag.strip() for tag in tags_text.split(",")]
            return tags[:7]
        else:
            # Free fallback: Extract keywords using frequency analysis
            return generate_tags_free(content)

    except Exception as e:
        logger.error(f"Error generating tags: {str(e)}")
        return generate_tags_free(content)

def generate_tags_free(content: str) -> List[str]:
    """Free tag generation using keyword extraction"""
    from collections import Counter
    import re

    # Simple keyword extraction
    words = re.findall(r'\b[a-z]{4,}\b', content.lower())

    # Common stop words
    stop_words = {
        'this', 'that', 'these', 'those', 'what', 'which', 'when', 'where',
        'who', 'whom', 'whose', 'why', 'how', 'with', 'about', 'against',
        'between', 'into', 'through', 'during', 'before', 'after', 'above',
        'below', 'from', 'down', 'out', 'off', 'over', 'under', 'again',
        'further', 'then', 'once', 'here', 'there', 'when', 'where', 'why',
        'how', 'all', 'both', 'each', 'few', 'more', 'most', 'other', 'some',
        'such', 'only', 'same', 'than', 'that', 'have', 'has', 'had',
        'been', 'being', 'does', 'doing', 'will', 'would', 'could', 'should'
    }

    # Filter and count words
    filtered_words = [w for w in words if w not in stop_words and len(w) > 4]
    word_counts = Counter(filtered_words)

    # Get top keywords
    top_keywords = [word for word, _ in word_counts.most_common(7)]

    return top_keywords if top_keywords else ["untagged"]

async def generate_summary(content: str) -> str:
    """Generate summary using Claude or fallback to free method"""
    try:
        if anthropic_client:
            message = anthropic_client.messages.create(
                model=Config.CLAUDE_MODEL,
                max_tokens=300,
                messages=[{
                    "role": "user",
                    "content": f"Summarize this content in 2-3 sentences:\n\n{content[:4000]}..."
                }]
            )

            return message.content[0].text.strip()
        else:
            # Free fallback
            return generate_summary_free(content)

    except Exception as e:
        logger.error(f"Error generating summary: {str(e)}")
        return generate_summary_free(content)

def generate_summary_free(content: str) -> str:
    """Free summary generation using simple extraction"""
    sentences = content.split('.')
    # Take first 3 sentences
    summary_sentences = sentences[:3]
    summary = '. '.join(s.strip() for s in summary_sentences if s.strip())

    if len(summary) > 300:
        summary = summary[:297] + "..."

    return summary if summary else "Content preview: " + content[:200] + "..."

async def process_local_file(file_path: str) -> Dict[str, Any]:
    """Process a local file and store it in the knowledge base"""
    try:
        # Validate file
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {file_path}")

        if path.suffix.lower() not in Config.SUPPORTED_FILE_TYPES:
            raise ValueError(f"Unsupported file type: {path.suffix}")

        # Extract text using free methods
        full_text = DocumentProcessor.extract_text(file_path)

        if not full_text:
            raise ValueError("No text could be extracted from the file")

        # Generate document ID
        doc_id = hashlib.md5(f"{path.name}_{datetime.now().isoformat()}".encode()).hexdigest()

        # Generate tags
        tags = await generate_tags(full_text[:3000])

        # Generate summary
        summary = await generate_summary(full_text[:5000])

        # Chunk the text
        chunks = chunk_text(full_text, chunk_size=1000, overlap=100)
        chunks = chunks[:10]  # Limit chunks for demo

        # Store in ChromaDB
        chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]

        metadata = {
            "source": str(path),
            "file_name": path.name,
            "file_type": path.suffix,
            "processed_at": datetime.now().isoformat(),
            "tags": ", ".join(tags),
            "summary": summary,
            "doc_id": doc_id
        }

        collection.add(
            documents=chunks,
            ids=chunk_ids,
            metadatas=[metadata for _ in chunks]
        )

        return {
            "success": True,
            "doc_id": doc_id,
            "file_name": path.name,
            "tags": tags,
            "summary": summary,
            "chunks_processed": len(chunks),
            "metadata": metadata
        }

    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        return {
            "success": False,
            "error": str(e)
        }

async def process_web_content(url_or_query: str) -> Dict[str, Any]:
    """Process web content from URL or search query"""
    try:
        # Check if it's a URL or search query
        is_url = url_or_query.startswith(('http://', 'https://'))

        if is_url:
            content = await fetch_web_content_free(url_or_query)
            source = url_or_query
        else:
            # It's a search query
            search_results = await search_web_free(url_or_query, num_results=3)
            if not search_results:
                raise ValueError("No search results found")

            # Process the first result
            first_result = search_results[0]
            content = await fetch_web_content_free(first_result['url'])
            source = first_result['url']

            # Add search context
            content = f"Search Query: {url_or_query}\n\n{first_result['title']}\n\n{content}"

        if not content:
            raise ValueError("Failed to fetch content")

        # Generate document ID
        doc_id = hashlib.md5(f"{source}_{datetime.now().isoformat()}".encode()).hexdigest()

        # Generate tags
        tags = await generate_tags(content[:3000])

        # Generate summary
        summary = await generate_summary(content[:5000])

        # Chunk the content
        chunks = chunk_text(content, chunk_size=1000, overlap=100)
        chunks = chunks[:10]  # Limit for demo

        # Store in ChromaDB
        chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]

        metadata = {
            "source": source,
            "url": source if is_url else f"Search: {url_or_query}",
            "content_type": "web",
            "processed_at": datetime.now().isoformat(),
            "tags": ", ".join(tags),
            "summary": summary,
            "doc_id": doc_id
        }

        collection.add(
            documents=chunks,
            ids=chunk_ids,
            metadatas=[metadata for _ in chunks]
        )

        return {
            "success": True,
            "doc_id": doc_id,
            "url": source,
            "tags": tags,
            "summary": summary,
            "chunks_processed": len(chunks),
            "metadata": metadata,
            "search_query": url_or_query if not is_url else None
        }

    except Exception as e:
        logger.error(f"Error processing web content: {str(e)}")
        return {
            "success": False,
            "error": str(e)
        }

async def search_knowledge_base(query: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Perform semantic search in the knowledge base"""
    try:
        results = collection.query(
            query_texts=[query],
            n_results=limit
        )

        if not results["ids"][0]:
            return []

        # Format results
        formatted_results = []
        seen_docs = set()

        for i, doc_id in enumerate(results["ids"][0]):
            metadata = results["metadatas"][0][i]

            # Deduplicate by document
            if metadata["doc_id"] not in seen_docs:
                seen_docs.add(metadata["doc_id"])
                formatted_results.append({
                    "doc_id": metadata["doc_id"],
                    "source": metadata.get("source", "Unknown"),
                    "tags": metadata.get("tags", "").split(", "),
                    "summary": metadata.get("summary", ""),
                    "relevance_score": 1 - results["distances"][0][i],
                    "processed_at": metadata.get("processed_at", "")
                })

        return formatted_results

    except Exception as e:
        logger.error(f"Error searching knowledge base: {str(e)}")
        return []

async def get_document_details(doc_id: str) -> Dict[str, Any]:
    """Get detailed information about a document"""
    try:
        results = collection.get(
            where={"doc_id": doc_id},
            limit=1
        )

        if not results["ids"]:
            return {"error": "Document not found"}

        metadata = results["metadatas"][0]
        return {
            "doc_id": doc_id,
            "source": metadata.get("source", "Unknown"),
            "tags": metadata.get("tags", "").split(", "),
            "summary": metadata.get("summary", ""),
            "processed_at": metadata.get("processed_at", ""),
            "file_type": metadata.get("file_type", ""),
            "content_preview": results["documents"][0][:500] + "..."
        }

    except Exception as e:
        logger.error(f"Error getting document details: {str(e)}")
        return {"error": str(e)}
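The boundary/overlap behavior of the deleted chunk_text is easy to verify in isolation (a sketch; the sample text is invented): each chunk tries to end on a sentence boundary past the midpoint of the window, and the next chunk restarts 100 characters earlier.

text = "First sentence here. " * 120            # ~2,500 characters
pieces = chunk_text(text, chunk_size=1000, overlap=100)
print(len(pieces))                               # a handful of overlapping chunks
print(all(p.endswith(".") for p in pieces))      # True: boundaries land on periods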
mcp_tools/__init__.py
ADDED
@@ -0,0 +1 @@
# MCP tools module initialization
mcp_tools/generative_tool.py
ADDED
@@ -0,0 +1,342 @@
import logging
from typing import List, Dict, Any, Optional
import asyncio

from services.llm_service import LLMService
from mcp_tools.search_tool import SearchTool
from core.models import SearchResult

logger = logging.getLogger(__name__)

class GenerativeTool:
    def __init__(self, llm_service: LLMService, search_tool: Optional[SearchTool] = None):
        self.llm_service = llm_service
        self.search_tool = search_tool

    async def summarize(self, content: str, style: str = "concise", max_length: Optional[int] = None) -> str:
        """Generate a summary of the given content"""
        try:
            if not content.strip():
                return "No content provided for summarization."

            logger.info(f"Generating {style} summary for content of length {len(content)}")

            summary = await self.llm_service.summarize(content, style, max_length)

            logger.info(f"Generated summary of length {len(summary)}")
            return summary

        except Exception as e:
            logger.error(f"Error generating summary: {str(e)}")
            return f"Error generating summary: {str(e)}"

    async def generate_tags(self, content: str, max_tags: int = 5) -> List[str]:
        """Generate relevant tags for the given content"""
        try:
            if not content.strip():
                return []

            logger.info(f"Generating up to {max_tags} tags for content")

            tags = await self.llm_service.generate_tags(content, max_tags)

            logger.info(f"Generated {len(tags)} tags")
            return tags

        except Exception as e:
            logger.error(f"Error generating tags: {str(e)}")
            return []

    async def categorize(self, content: str, categories: List[str]) -> str:
        """Categorize content into one of the provided categories"""
        try:
            if not content.strip():
                return "Uncategorized"

            if not categories:
                categories = ["Technology", "Business", "Science", "Education", "Entertainment", "News", "Research", "Other"]

            logger.info(f"Categorizing content into one of {len(categories)} categories")

            category = await self.llm_service.categorize(content, categories)

            logger.info(f"Categorized as: {category}")
            return category

        except Exception as e:
            logger.error(f"Error categorizing content: {str(e)}")
            return "Uncategorized"

    async def answer_question(self, question: str, context_results: Optional[List[SearchResult]] = None) -> str:
        """Answer a question using the provided context or RAG"""
        try:
            if not question.strip():
                return "No question provided."

            logger.info(f"Answering question: {question[:100]}...")

            # If no context provided and a search tool is available, search for relevant context
            if not context_results and self.search_tool:
                logger.info("No context provided, searching for relevant information")
                context_results = await self.search_tool.search(question, top_k=5)

            # Prepare context from search results
            if context_results:
                context_texts = []
                for result in context_results:
                    context_texts.append(f"Source: {result.document_id}\nContent: {result.content}\n")

                context = "\n---\n".join(context_texts)
                logger.info(f"Using context from {len(context_results)} sources")
            else:
                context = ""
                logger.info("No context available for answering question")

            # Generate answer
            answer = await self.llm_service.answer_question(question, context)

            logger.info(f"Generated answer of length {len(answer)}")
            return answer

        except Exception as e:
            logger.error(f"Error answering question: {str(e)}")
            return f"I encountered an error while trying to answer your question: {str(e)}"

    async def generate_outline(self, topic: str, num_sections: int = 5, detail_level: str = "medium") -> str:
        """Generate an outline for the given topic"""
        try:
            if not topic.strip():
                return "No topic provided."

            detail_descriptions = {
                "brief": "brief bullet points",
                "medium": "detailed bullet points with descriptions",
                "detailed": "comprehensive outline with sub-sections and explanations"
            }

            detail_desc = detail_descriptions.get(detail_level, "detailed bullet points")

            prompt = f"""Create a {detail_desc} outline for the topic: "{topic}"

The outline should have {num_sections} main sections and be well-structured and informative.

Format the outline clearly with proper numbering and indentation.

Topic: {topic}

Outline:"""

            outline = await self.llm_service.generate_text(prompt, max_tokens=800, temperature=0.7)

            logger.info(f"Generated outline for topic: {topic}")
            return outline

        except Exception as e:
            logger.error(f"Error generating outline: {str(e)}")
            return f"Error generating outline: {str(e)}"

    async def explain_concept(self, concept: str, audience: str = "general", length: str = "medium") -> str:
        """Explain a concept for a specific audience"""
        try:
            if not concept.strip():
                return "No concept provided."

            audience_styles = {
                "general": "a general audience using simple, clear language",
                "technical": "a technical audience with appropriate jargon and detail",
                "beginner": "beginners with no prior knowledge, using analogies and examples",
                "expert": "experts in the field with advanced terminology and depth"
            }

            length_guidance = {
                "brief": "Keep the explanation concise and to the point (2-3 paragraphs).",
                "medium": "Provide a comprehensive explanation (4-6 paragraphs).",
                "detailed": "Give a thorough, in-depth explanation with examples."
            }

            audience_desc = audience_styles.get(audience, "a general audience")
            length_desc = length_guidance.get(length, "Provide a comprehensive explanation.")

            prompt = f"""Explain the concept of "{concept}" for {audience_desc}.

{length_desc}

Make sure to:
- Use appropriate language for the audience
- Include relevant examples or analogies
- Structure the explanation logically
- Ensure clarity and accuracy

Concept to explain: {concept}

Explanation:"""

            explanation = await self.llm_service.generate_text(prompt, max_tokens=600, temperature=0.5)

            logger.info(f"Generated explanation for concept: {concept}")
            return explanation

        except Exception as e:
            logger.error(f"Error explaining concept: {str(e)}")
            return f"Error explaining concept: {str(e)}"

    async def compare_concepts(self, concept1: str, concept2: str, aspects: Optional[List[str]] = None) -> str:
        """Compare two concepts across specified aspects"""
        try:
            if not concept1.strip() or not concept2.strip():
                return "Both concepts must be provided for comparison."

            if not aspects:
                aspects = ["definition", "key features", "advantages", "disadvantages", "use cases"]

            aspects_str = ", ".join(aspects)

            prompt = f"""Compare and contrast "{concept1}" and "{concept2}" across the following aspects: {aspects_str}.

Structure your comparison clearly, addressing each aspect for both concepts.

Format:
## Comparison: {concept1} vs {concept2}

For each aspect, provide:
- **{concept1}**: [description]
- **{concept2}**: [description]
- **Key Difference**: [summary]

Concepts to compare:
1. {concept1}
2. {concept2}

Comparison:"""

            comparison = await self.llm_service.generate_text(prompt, max_tokens=800, temperature=0.6)

            logger.info(f"Generated comparison between {concept1} and {concept2}")
            return comparison

        except Exception as e:
            logger.error(f"Error comparing concepts: {str(e)}")
            return f"Error comparing concepts: {str(e)}"

    async def generate_questions(self, content: str, question_type: str = "comprehension", num_questions: int = 5) -> List[str]:
        """Generate questions based on the provided content"""
        try:
            if not content.strip():
                return []

            question_types = {
                "comprehension": "comprehension questions that test understanding of key concepts",
                "analysis": "analytical questions that require deeper thinking and evaluation",
                "application": "application questions that ask how to use the concepts in practice",
                "creative": "creative questions that encourage original thinking and exploration",
                "factual": "factual questions about specific details and information"
            }

            question_desc = question_types.get(question_type, "comprehension questions")

            # Content is truncated to 2000 characters to keep the prompt bounded.
            prompt = f"""Based on the following content, generate {num_questions} {question_desc}.

The questions should be:
- Clear and well-formulated
- Relevant to the content
- Appropriate for the specified type
- Engaging and thought-provoking

Content:
{content[:2000]}

Questions:"""

            response = await self.llm_service.generate_text(prompt, max_tokens=400, temperature=0.7)

            # Parse questions from response
            questions = []
            lines = response.split('\n')

            for line in lines:
                line = line.strip()
                if line and ('?' in line or line.startswith(('1.', '2.', '3.', '4.', '5.', '-', '*'))):
                    # Clean up the question
                    question = line.lstrip('0123456789.-* ').strip()
                    if question and '?' in question:
                        questions.append(question)

            logger.info(f"Generated {len(questions)} {question_type} questions")
            return questions[:num_questions]

        except Exception as e:
            logger.error(f"Error generating questions: {str(e)}")
            return []

    async def paraphrase_text(self, text: str, style: str = "formal", preserve_meaning: bool = True) -> str:
        """Paraphrase text in a different style while preserving meaning"""
        try:
            if not text.strip():
                return "No text provided for paraphrasing."

            style_instructions = {
                "formal": "formal, professional language",
                "casual": "casual, conversational language",
                "academic": "academic, scholarly language",
                "simple": "simple, easy-to-understand language",
                "technical": "technical, precise language"
            }

            style_desc = style_instructions.get(style, "clear, appropriate language")
            meaning_instruction = "while preserving the exact meaning and key information" if preserve_meaning else "while maintaining the general intent"

            prompt = f"""Paraphrase the following text using {style_desc} {meaning_instruction}.

Original text:
{text}

Paraphrased text:"""

            paraphrase = await self.llm_service.generate_text(prompt, max_tokens=len(text.split()) * 2, temperature=0.6)

            logger.info(f"Paraphrased text in {style} style")
            return paraphrase.strip()

        except Exception as e:
            logger.error(f"Error paraphrasing text: {str(e)}")
            return f"Error paraphrasing text: {str(e)}"

    async def extract_key_insights(self, content: str, num_insights: int = 5) -> List[str]:
        """Extract key insights from the provided content"""
        try:
            if not content.strip():
                return []

            prompt = f"""Analyze the following content and extract {num_insights} key insights or takeaways.
+
|
312 |
+
Each insight should be:
|
313 |
+
- A clear, concise statement
|
314 |
+
- Significant and meaningful
|
315 |
+
- Based on the content provided
|
316 |
+
- Actionable or thought-provoking when possible
|
317 |
+
|
318 |
+
Content:
|
319 |
+
{content[:3000]} # Limit content length
|
320 |
+
|
321 |
+
Key Insights:"""
|
322 |
+
|
323 |
+
response = await self.llm_service.generate_text(prompt, max_tokens=400, temperature=0.6)
|
324 |
+
|
325 |
+
# Parse insights from response
|
326 |
+
insights = []
|
327 |
+
lines = response.split('\n')
|
328 |
+
|
329 |
+
for line in lines:
|
330 |
+
line = line.strip()
|
331 |
+
if line and (line.startswith(('1.', '2.', '3.', '4.', '5.', '-', '*')) or len(insights) == 0):
|
332 |
+
# Clean up the insight
|
333 |
+
insight = line.lstrip('0123456789.-* ').strip()
|
334 |
+
if insight and len(insight) > 10: # Minimum insight length
|
335 |
+
insights.append(insight)
|
336 |
+
|
337 |
+
logger.info(f"Extracted {len(insights)} key insights")
|
338 |
+
return insights[:num_insights]
|
339 |
+
|
340 |
+
except Exception as e:
|
341 |
+
logger.error(f"Error extracting insights: {str(e)}")
|
342 |
+
return []
|
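A minimal usage sketch for the generation helpers above. The class name GenerativeTool, the llm_service constructor argument, and the LLMService class are assumptions inferred from the method bodies (each calls self.llm_service.generate_text) and the file list; adjust to the actual definitions earlier in this file and in services/llm_service.py.

import asyncio

from mcp_tools.generative_tool import GenerativeTool  # assumed class name
from services.llm_service import LLMService           # assumed constructor

async def demo() -> None:
    tool = GenerativeTool(llm_service=LLMService())

    # explain_concept builds the audience/length prompt shown above.
    explanation = await tool.explain_concept("vector embeddings", audience="beginner", length="brief")
    print(explanation)

    # generate_questions parses numbered/bulleted lines containing '?'.
    for q in await tool.generate_questions(explanation, question_type="comprehension", num_questions=3):
        print("-", q)

asyncio.run(demo())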
mcp_tools/ingestion_tool.py
ADDED
@@ -0,0 +1,330 @@
import logging
import asyncio
from typing import Dict, Any, Optional
import tempfile
import os
from pathlib import Path
import uuid

from core.document_parser import DocumentParser
from core.chunker import TextChunker
from core.text_preprocessor import TextPreprocessor
from services.vector_store_service import VectorStoreService
from services.document_store_service import DocumentStoreService
from services.embedding_service import EmbeddingService
from services.ocr_service import OCRService

logger = logging.getLogger(__name__)

class IngestionTool:
    def __init__(self, vector_store: VectorStoreService, document_store: DocumentStoreService,
                 embedding_service: EmbeddingService, ocr_service: OCRService):
        self.vector_store = vector_store
        self.document_store = document_store
        self.embedding_service = embedding_service
        self.ocr_service = ocr_service

        self.document_parser = DocumentParser()
        # Pass OCR service to document parser
        self.document_parser.ocr_service = ocr_service

        self.text_chunker = TextChunker()
        self.text_preprocessor = TextPreprocessor()

    async def process_document(self, file_path: str, file_type: str, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Process a document through the full ingestion pipeline"""
        if task_id is None:
            task_id = str(uuid.uuid4())

        try:
            logger.info(f"Starting document processing for {file_path}")

            # Step 1: Parse the document
            filename = Path(file_path).name
            document = await self.document_parser.parse_document(file_path, filename)

            if not document.content:
                logger.warning(f"No content extracted from document {filename}")
                return {
                    "success": False,
                    "error": "No content could be extracted from the document",
                    "task_id": task_id
                }

            # Step 2: Store the document
            await self.document_store.store_document(document)

            # Step 3: Process content for embeddings
            chunks = await self._create_and_embed_chunks(document)

            if not chunks:
                logger.warning(f"No chunks created for document {document.id}")
                return {
                    "success": False,
                    "error": "Failed to create text chunks",
                    "task_id": task_id,
                    "document_id": document.id
                }

            # Step 4: Store embeddings
            success = await self.vector_store.add_chunks(chunks)

            if not success:
                logger.error(f"Failed to store embeddings for document {document.id}")
                return {
                    "success": False,
                    "error": "Failed to store embeddings",
                    "task_id": task_id,
                    "document_id": document.id
                }

            logger.info(f"Successfully processed document {document.id} with {len(chunks)} chunks")

            return {
                "success": True,
                "task_id": task_id,
                "document_id": document.id,
                "filename": document.filename,
                "chunks_created": len(chunks),
                "content_length": len(document.content),
                "doc_type": document.doc_type.value,
                "message": f"Successfully processed {filename}"
            }

        except Exception as e:
            logger.error(f"Error processing document {file_path}: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "task_id": task_id,
                "message": f"Failed to process document: {str(e)}"
            }

    async def _create_and_embed_chunks(self, document) -> list:
        """Create chunks and generate embeddings"""
        try:
            # Step 1: Create chunks
            chunks = self.text_chunker.chunk_document(
                document.id,
                document.content,
                method="recursive"
            )

            if not chunks:
                return []

            # Step 2: Optimize chunks for embedding
            optimized_chunks = self.text_chunker.optimize_chunks_for_embedding(chunks)

            # Step 3: Generate embeddings
            texts = [chunk.content for chunk in optimized_chunks]
            embeddings = await self.embedding_service.generate_embeddings(texts)

            # Step 4: Add embeddings to chunks
            embedded_chunks = []
            for i, chunk in enumerate(optimized_chunks):
                if i < len(embeddings):
                    chunk.embedding = embeddings[i]
                    embedded_chunks.append(chunk)

            return embedded_chunks

        except Exception as e:
            logger.error(f"Error creating and embedding chunks: {str(e)}")
            return []

    async def process_url(self, url: str, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Process a document from a URL"""
        try:
            import requests
            from urllib.parse import urlparse

            # Download the file
            response = requests.get(url, timeout=30)
            response.raise_for_status()

            # Determine file type from URL or content-type
            parsed_url = urlparse(url)
            filename = Path(parsed_url.path).name or "downloaded_file"

            # Create temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{filename}") as tmp_file:
                tmp_file.write(response.content)
                tmp_file_path = tmp_file.name

            try:
                # Process the downloaded file
                result = await self.process_document(tmp_file_path, "", task_id)
                result["source_url"] = url
                return result
            finally:
                # Clean up temporary file
                if os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)

        except Exception as e:
            logger.error(f"Error processing URL {url}: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "task_id": task_id or str(uuid.uuid4()),
                "source_url": url
            }

    async def process_text_content(self, content: str, filename: str = "text_content.txt",
                                   task_id: Optional[str] = None) -> Dict[str, Any]:
        """Process raw text content directly"""
        try:
            from core.models import Document, DocumentType
            from datetime import datetime

            # Create document object
            document = Document(
                id=str(uuid.uuid4()),
                filename=filename,
                content=content,
                doc_type=DocumentType.TEXT,
                file_size=len(content.encode('utf-8')),
                created_at=datetime.utcnow(),
                metadata={
                    "source": "direct_text_input",
                    "content_length": len(content),
                    "word_count": len(content.split())
                }
            )

            # Store the document
            await self.document_store.store_document(document)

            # Process content for embeddings
            chunks = await self._create_and_embed_chunks(document)

            if chunks:
                await self.vector_store.add_chunks(chunks)

            return {
                "success": True,
                "task_id": task_id or str(uuid.uuid4()),
                "document_id": document.id,
                "filename": filename,
                "chunks_created": len(chunks),
                "content_length": len(content),
                "message": "Successfully processed text content"
            }

        except Exception as e:
            logger.error(f"Error processing text content: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "task_id": task_id or str(uuid.uuid4())
            }

    async def reprocess_document(self, document_id: str, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Reprocess an existing document (useful for updating embeddings)"""
        try:
            # Get the document
            document = await self.document_store.get_document(document_id)

            if not document:
                return {
                    "success": False,
                    "error": f"Document {document_id} not found",
                    "task_id": task_id or str(uuid.uuid4())
                }

            # Remove existing chunks from vector store
            await self.vector_store.delete_document(document_id)

            # Recreate and embed chunks
            chunks = await self._create_and_embed_chunks(document)

            if chunks:
                await self.vector_store.add_chunks(chunks)

            return {
                "success": True,
                "task_id": task_id or str(uuid.uuid4()),
                "document_id": document_id,
                "filename": document.filename,
                "chunks_created": len(chunks),
                "message": f"Successfully reprocessed {document.filename}"
            }

        except Exception as e:
            logger.error(f"Error reprocessing document {document_id}: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "task_id": task_id or str(uuid.uuid4()),
                "document_id": document_id
            }

    async def batch_process_directory(self, directory_path: str, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Process multiple documents from a directory"""
        try:
            directory = Path(directory_path)
            if not directory.exists() or not directory.is_dir():
                return {
                    "success": False,
                    "error": f"Directory {directory_path} does not exist",
                    "task_id": task_id or str(uuid.uuid4())
                }

            # Supported file extensions
            supported_extensions = {'.txt', '.pdf', '.docx', '.png', '.jpg', '.jpeg', '.bmp', '.tiff'}

            # Find all supported files
            files_to_process = []
            for ext in supported_extensions:
                files_to_process.extend(directory.glob(f"*{ext}"))
                files_to_process.extend(directory.glob(f"*{ext.upper()}"))

            if not files_to_process:
                return {
                    "success": False,
                    "error": "No supported files found in directory",
                    "task_id": task_id or str(uuid.uuid4())
                }

            # Process files
            results = []
            successful = 0
            failed = 0

            for file_path in files_to_process:
                try:
                    result = await self.process_document(str(file_path), file_path.suffix)
                    results.append(result)

                    if result.get("success"):
                        successful += 1
                    else:
                        failed += 1

                except Exception as e:
                    failed += 1
                    results.append({
                        "success": False,
                        "error": str(e),
                        "filename": file_path.name
                    })

            return {
                "success": True,
                "task_id": task_id or str(uuid.uuid4()),
                "directory": str(directory),
                "total_files": len(files_to_process),
                "successful": successful,
                "failed": failed,
                "results": results,
                "message": f"Processed {successful}/{len(files_to_process)} files successfully"
            }

        except Exception as e:
            logger.error(f"Error batch processing directory {directory_path}: {str(e)}")
            return {
                "success": False,
                "error": str(e),
                "task_id": task_id or str(uuid.uuid4())
            }
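A sketch of driving the ingestion pipeline end to end. The no-argument constructors for VectorStoreService, EmbeddingService, and OCRService are assumptions (their actual signatures live in the services/ modules added below); DocumentStoreService's no-argument __init__ does appear later in this commit.

import asyncio

from mcp_tools.ingestion_tool import IngestionTool
from services.vector_store_service import VectorStoreService
from services.document_store_service import DocumentStoreService
from services.embedding_service import EmbeddingService
from services.ocr_service import OCRService

async def ingest(path: str) -> None:
    tool = IngestionTool(
        vector_store=VectorStoreService(),      # constructor args assumed
        document_store=DocumentStoreService(),
        embedding_service=EmbeddingService(),   # constructor args assumed
        ocr_service=OCRService(),               # constructor args assumed
    )
    # file_type is accepted but unused above; the parser infers the type
    # from the filename.
    result = await tool.process_document(path, file_type="")
    print(result["message"] if result.get("success") else result["error"])

asyncio.run(ingest("docs/example.pdf"))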
mcp_tools/search_tool.py
ADDED
@@ -0,0 +1,423 @@
import logging
from typing import List, Dict, Any, Optional
import asyncio

from core.models import SearchResult
from services.vector_store_service import VectorStoreService
from services.embedding_service import EmbeddingService
from services.document_store_service import DocumentStoreService
import config

logger = logging.getLogger(__name__)

class SearchTool:
    def __init__(self, vector_store: VectorStoreService, embedding_service: EmbeddingService,
                 document_store: Optional[DocumentStoreService] = None):
        self.vector_store = vector_store
        self.embedding_service = embedding_service
        self.document_store = document_store
        self.config = config.config

    async def search(self, query: str, top_k: int = 5, filters: Optional[Dict[str, Any]] = None,
                     similarity_threshold: Optional[float] = None) -> List[SearchResult]:
        """Perform semantic search"""
        try:
            if not query.strip():
                logger.warning("Empty search query provided")
                return []

            # Use default threshold if not provided
            if similarity_threshold is None:
                similarity_threshold = self.config.SIMILARITY_THRESHOLD

            logger.info(f"Performing semantic search for: '{query}' (top_k={top_k})")

            # Generate query embedding
            query_embedding = await self.embedding_service.generate_single_embedding(query)

            if not query_embedding:
                logger.error("Failed to generate query embedding")
                return []

            # Perform vector search
            results = await self.vector_store.search(
                query_embedding=query_embedding,
                top_k=top_k,
                filters=filters
            )

            # Filter by similarity threshold
            filtered_results = [
                result for result in results
                if result.score >= similarity_threshold
            ]

            logger.info(f"Found {len(filtered_results)} results above threshold {similarity_threshold}")

            # Enhance results with additional metadata if document store is available
            if self.document_store:
                enhanced_results = await self._enhance_results_with_metadata(filtered_results)
                return enhanced_results

            return filtered_results

        except Exception as e:
            logger.error(f"Error performing semantic search: {str(e)}")
            return []

    async def _enhance_results_with_metadata(self, results: List[SearchResult]) -> List[SearchResult]:
        """Enhance search results with document metadata"""
        try:
            enhanced_results = []

            for result in results:
                try:
                    # Get document metadata
                    document = await self.document_store.get_document(result.document_id)

                    if document:
                        # Add document metadata to result
                        enhanced_metadata = {
                            **result.metadata,
                            "document_filename": document.filename,
                            "document_type": document.doc_type.value,
                            "document_tags": document.tags,
                            "document_category": document.category,
                            "document_created_at": document.created_at.isoformat(),
                            "document_summary": document.summary
                        }

                        enhanced_result = SearchResult(
                            chunk_id=result.chunk_id,
                            document_id=result.document_id,
                            content=result.content,
                            score=result.score,
                            metadata=enhanced_metadata
                        )

                        enhanced_results.append(enhanced_result)
                    else:
                        # Document not found, use original result
                        enhanced_results.append(result)

                except Exception as e:
                    logger.warning(f"Error enhancing result {result.chunk_id}: {str(e)}")
                    enhanced_results.append(result)

            return enhanced_results

        except Exception as e:
            logger.error(f"Error enhancing results: {str(e)}")
            return results

    async def multi_query_search(self, queries: List[str], top_k: int = 5,
                                 aggregate_method: str = "merge") -> List[SearchResult]:
        """Perform search with multiple queries and aggregate results"""
        try:
            all_results = []

            # Perform search for each query
            for query in queries:
                if query.strip():
                    query_results = await self.search(query, top_k)
                    all_results.extend(query_results)

            if not all_results:
                return []

            # Aggregate results
            if aggregate_method == "merge":
                return await self._merge_results(all_results, top_k)
            elif aggregate_method == "intersect":
                return await self._intersect_results(all_results, top_k)
            elif aggregate_method == "average":
                return await self._average_results(all_results, top_k)
            else:
                # Default to merge
                return await self._merge_results(all_results, top_k)

        except Exception as e:
            logger.error(f"Error in multi-query search: {str(e)}")
            return []

    async def _merge_results(self, results: List[SearchResult], top_k: int) -> List[SearchResult]:
        """Merge results and remove duplicates, keeping highest scores"""
        try:
            # Group by chunk_id and keep highest score
            chunk_scores = {}
            chunk_results = {}

            for result in results:
                chunk_id = result.chunk_id
                if chunk_id not in chunk_scores or result.score > chunk_scores[chunk_id]:
                    chunk_scores[chunk_id] = result.score
                    chunk_results[chunk_id] = result

            # Sort by score and return top_k
            merged_results = list(chunk_results.values())
            merged_results.sort(key=lambda x: x.score, reverse=True)

            return merged_results[:top_k]

        except Exception as e:
            logger.error(f"Error merging results: {str(e)}")
            return results[:top_k]

    async def _intersect_results(self, results: List[SearchResult], top_k: int) -> List[SearchResult]:
        """Find chunks that appear in multiple queries"""
        try:
            # Count occurrences of each chunk
            chunk_counts = {}
            chunk_results = {}

            for result in results:
                chunk_id = result.chunk_id
                chunk_counts[chunk_id] = chunk_counts.get(chunk_id, 0) + 1

                if chunk_id not in chunk_results or result.score > chunk_results[chunk_id].score:
                    chunk_results[chunk_id] = result

            # Filter chunks that appear more than once
            intersect_results = [
                result for chunk_id, result in chunk_results.items()
                if chunk_counts[chunk_id] > 1
            ]

            # Sort by score
            intersect_results.sort(key=lambda x: x.score, reverse=True)

            return intersect_results[:top_k]

        except Exception as e:
            logger.error(f"Error intersecting results: {str(e)}")
            return []

    async def _average_results(self, results: List[SearchResult], top_k: int) -> List[SearchResult]:
        """Average scores for chunks that appear multiple times"""
        try:
            # Group by chunk_id and calculate average scores
            chunk_groups = {}

            for result in results:
                chunk_id = result.chunk_id
                if chunk_id not in chunk_groups:
                    chunk_groups[chunk_id] = []
                chunk_groups[chunk_id].append(result)

            # Calculate average scores
            averaged_results = []
            for chunk_id, group in chunk_groups.items():
                avg_score = sum(r.score for r in group) / len(group)

                # Use the result with the highest individual score but update the score to average
                best_result = max(group, key=lambda x: x.score)
                averaged_result = SearchResult(
                    chunk_id=best_result.chunk_id,
                    document_id=best_result.document_id,
                    content=best_result.content,
                    score=avg_score,
                    metadata={
                        **best_result.metadata,
                        "query_count": len(group),
                        "score_range": f"{min(r.score for r in group):.3f}-{max(r.score for r in group):.3f}"
                    }
                )
                averaged_results.append(averaged_result)

            # Sort by average score
            averaged_results.sort(key=lambda x: x.score, reverse=True)

            return averaged_results[:top_k]

        except Exception as e:
            logger.error(f"Error averaging results: {str(e)}")
            return results[:top_k]

    async def search_by_document(self, document_id: str, query: str, top_k: int = 5) -> List[SearchResult]:
        """Search within a specific document"""
        try:
            filters = {"document_id": document_id}
            return await self.search(query, top_k, filters)

        except Exception as e:
            logger.error(f"Error searching within document {document_id}: {str(e)}")
            return []

    async def search_by_category(self, category: str, query: str, top_k: int = 5) -> List[SearchResult]:
        """Search within documents of a specific category"""
        try:
            if not self.document_store:
                logger.warning("Document store not available for category search")
                return await self.search(query, top_k)

            # Get documents in the category
            documents = await self.document_store.list_documents(
                limit=1000,  # Adjust as needed
                filters={"category": category}
            )

            if not documents:
                logger.info(f"No documents found in category '{category}'")
                return []

            # Extract document IDs
            document_ids = [doc.id for doc in documents]

            # Search with document ID filter
            filters = {"document_ids": document_ids}
            return await self.search(query, top_k, filters)

        except Exception as e:
            logger.error(f"Error searching by category {category}: {str(e)}")
            return []

    async def search_with_date_range(self, query: str, start_date, end_date, top_k: int = 5) -> List[SearchResult]:
        """Search documents within a date range"""
        try:
            if not self.document_store:
                logger.warning("Document store not available for date range search")
                return await self.search(query, top_k)

            # Get documents in the date range
            documents = await self.document_store.list_documents(
                limit=1000,  # Adjust as needed
                filters={
                    "created_after": start_date,
                    "created_before": end_date
                }
            )

            if not documents:
                logger.info("No documents found in date range")
                return []

            # Extract document IDs
            document_ids = [doc.id for doc in documents]

            # Search with document ID filter
            filters = {"document_ids": document_ids}
            return await self.search(query, top_k, filters)

        except Exception as e:
            logger.error(f"Error searching with date range: {str(e)}")
            return []

    async def get_search_suggestions(self, partial_query: str, limit: int = 5) -> List[str]:
        """Get search suggestions based on partial query"""
        try:
            # This is a simple implementation
            # In a production system, you might want to use a more sophisticated approach

            if len(partial_query) < 2:
                return []

            # Search for the partial query
            results = await self.search(partial_query, top_k=20)

            # Extract potential query expansions from content
            suggestions = set()

            for result in results:
                content_words = result.content.lower().split()
                for i, word in enumerate(content_words):
                    if partial_query.lower() in word:
                        # Add the word itself
                        suggestions.add(word.strip('.,!?;:'))

                        # Add phrases that include this word
                        if i > 0:
                            phrase = f"{content_words[i-1]} {word}".strip('.,!?;:')
                            suggestions.add(phrase)
                        if i < len(content_words) - 1:
                            phrase = f"{word} {content_words[i+1]}".strip('.,!?;:')
                            suggestions.add(phrase)

            # Filter and sort suggestions
            filtered_suggestions = [
                s for s in suggestions
                if len(s) > len(partial_query) and s.startswith(partial_query.lower())
            ]

            return sorted(filtered_suggestions)[:limit]

        except Exception as e:
            logger.error(f"Error getting search suggestions: {str(e)}")
            return []

    async def explain_search(self, query: str, top_k: int = 3) -> Dict[str, Any]:
        """Provide detailed explanation of search process and results"""
        try:
            explanation = {
                "query": query,
                "steps": [],
                "results_analysis": {},
                "performance_metrics": {}
            }

            # Step 1: Query processing
            explanation["steps"].append({
                "step": "query_processing",
                "description": "Processing and normalizing the search query",
                "details": {
                    "original_query": query,
                    "cleaned_query": query.strip(),
                    "query_length": len(query)
                }
            })

            # Step 2: Embedding generation
            import time
            start_time = time.time()

            query_embedding = await self.embedding_service.generate_single_embedding(query)

            embedding_time = time.time() - start_time

            explanation["steps"].append({
                "step": "embedding_generation",
                "description": "Converting query to vector embedding",
                "details": {
                    "embedding_dimension": len(query_embedding) if query_embedding else 0,
                    "generation_time_ms": round(embedding_time * 1000, 2)
                }
            })

            # Step 3: Vector search
            start_time = time.time()

            results = await self.vector_store.search(query_embedding, top_k)

            search_time = time.time() - start_time

            explanation["steps"].append({
                "step": "vector_search",
                "description": "Searching vector database for similar content",
                "details": {
                    "search_time_ms": round(search_time * 1000, 2),
                    "results_found": len(results),
                    "top_score": results[0].score if results else 0,
                    "score_range": f"{min(r.score for r in results):.3f}-{max(r.score for r in results):.3f}" if results else "N/A"
                }
            })

            # Results analysis
            if results:
                explanation["results_analysis"] = {
                    "total_results": len(results),
                    "average_score": sum(r.score for r in results) / len(results),
                    "unique_documents": len(set(r.document_id for r in results)),
                    "content_lengths": [len(r.content) for r in results]
                }

            # Performance metrics
            explanation["performance_metrics"] = {
                "total_time_ms": round((embedding_time + search_time) * 1000, 2),
                "embedding_time_ms": round(embedding_time * 1000, 2),
                "search_time_ms": round(search_time * 1000, 2)
            }

            return explanation

        except Exception as e:
            logger.error(f"Error explaining search: {str(e)}")
            return {"error": str(e)}
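The aggregation modes in multi_query_search reward different things: "merge" keeps each chunk's single best score, "intersect" keeps only chunks surfaced by more than one query, and "average" smooths scores across queries. A minimal sketch of exercising two of the methods above, assuming search_tool is an already-constructed SearchTool:

import json

async def demo_search(search_tool) -> None:
    # With "intersect", a chunk must rank for both phrasings to survive.
    hits = await search_tool.multi_query_search(
        ["transformer attention", "self-attention mechanism"],
        top_k=5,
        aggregate_method="intersect",
    )
    for hit in hits:
        print(f"{hit.score:.3f}  {hit.document_id}  {hit.content[:80]}")

    # explain_search reports per-stage latency and the observed score range,
    # which is handy when tuning SIMILARITY_THRESHOLD.
    report = await search_tool.explain_search("transformer attention", top_k=3)
    print(json.dumps(report.get("performance_metrics", {}), indent=2))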
mcp_tools/utils.py
ADDED
@@ -0,0 +1,373 @@
import logging
import asyncio
import functools
from typing import Any, Callable, Dict, List, Optional
import time
import json
from pathlib import Path

logger = logging.getLogger(__name__)

def async_timer(func: Callable) -> Callable:
    """Decorator to time async function execution"""
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        try:
            result = await func(*args, **kwargs)
            end_time = time.time()
            logger.debug(f"{func.__name__} completed in {end_time - start_time:.3f}s")
            return result
        except Exception as e:
            end_time = time.time()
            logger.error(f"{func.__name__} failed after {end_time - start_time:.3f}s: {str(e)}")
            raise
    return wrapper

def retry_async(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0):
    """Decorator to retry async functions with exponential backoff"""
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            attempt = 1
            current_delay = delay

            while attempt <= max_attempts:
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_attempts:
                        logger.error(f"{func.__name__} failed after {max_attempts} attempts: {str(e)}")
                        raise

                    logger.warning(f"{func.__name__} attempt {attempt} failed: {str(e)}")
                    logger.info(f"Retrying in {current_delay}s...")

                    await asyncio.sleep(current_delay)
                    attempt += 1
                    current_delay *= backoff

        return wrapper
    return decorator

class MCPToolResponse:
    """Standardized response format for MCP tools"""

    def __init__(self, success: bool, data: Any = None, error: str = None,
                 metadata: Dict[str, Any] = None):
        self.success = success
        self.data = data
        self.error = error
        self.metadata = metadata or {}
        self.timestamp = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Convert response to dictionary"""
        result = {
            "success": self.success,
            "timestamp": self.timestamp
        }

        if self.success:
            result["data"] = self.data
        else:
            result["error"] = self.error

        if self.metadata:
            result["metadata"] = self.metadata

        return result

    @classmethod
    def success_response(cls, data: Any, metadata: Dict[str, Any] = None):
        """Create a success response"""
        return cls(success=True, data=data, metadata=metadata)

    @classmethod
    def error_response(cls, error: str, metadata: Dict[str, Any] = None):
        """Create an error response"""
        return cls(success=False, error=error, metadata=metadata)

def validate_required_params(params: Dict[str, Any], required: List[str]) -> Optional[str]:
    """Validate that required parameters are present"""
    missing = []
    for param in required:
        if param not in params or params[param] is None:
            missing.append(param)

    if missing:
        return f"Missing required parameters: {', '.join(missing)}"

    return None

def sanitize_filename(filename: str) -> str:
    """Sanitize filename for safe storage"""
    import re

    # Remove or replace invalid characters
    filename = re.sub(r'[<>:"/\\|?*]', '_', filename)

    # Remove leading/trailing dots and spaces
    filename = filename.strip('. ')

    # Limit length
    if len(filename) > 255:
        name, ext = Path(filename).stem, Path(filename).suffix
        max_name_len = 255 - len(ext)
        filename = name[:max_name_len] + ext

    # Ensure not empty
    if not filename:
        filename = "unnamed_file"

    return filename

def truncate_text(text: str, max_length: int, add_ellipsis: bool = True) -> str:
    """Truncate text to specified length"""
    if len(text) <= max_length:
        return text

    if add_ellipsis and max_length > 3:
        return text[:max_length - 3] + "..."
    else:
        return text[:max_length]

def extract_file_info(file_path: str) -> Dict[str, Any]:
    """Extract information about a file"""
    try:
        path = Path(file_path)
        stat = path.stat()

        return {
            "filename": path.name,
            "extension": path.suffix.lower(),
            "size_bytes": stat.st_size,
            "size_mb": round(stat.st_size / (1024 * 1024), 2),
            "created_time": stat.st_ctime,
            "modified_time": stat.st_mtime,
            "exists": path.exists(),
            "is_file": path.is_file(),
            "is_dir": path.is_dir()
        }
    except Exception as e:
        return {"error": str(e)}

async def batch_process(items: List[Any], processor: Callable, batch_size: int = 10,
                        max_concurrent: int = 5) -> List[Any]:
    """Process items in batches with concurrency control"""
    results = []
    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_item(item):
        async with semaphore:
            return await processor(item)

    # Process in batches
    for i in range(0, len(items), batch_size):
        batch = items[i:i + batch_size]
        batch_tasks = [process_item(item) for item in batch]
        batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
        results.extend(batch_results)

    return results

def format_file_size(size_bytes: int) -> str:
    """Format file size in human-readable format"""
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if size_bytes < 1024.0:
            return f"{size_bytes:.1f} {unit}"
        size_bytes /= 1024.0
    return f"{size_bytes:.1f} PB"

def calculate_reading_time(text: str, words_per_minute: int = 200) -> int:
    """Calculate estimated reading time in minutes"""
    word_count = len(text.split())
    return max(1, round(word_count / words_per_minute))

class ProgressTracker:
    """Track progress of long-running operations"""

    def __init__(self, total_items: int, description: str = "Processing"):
        self.total_items = total_items
        self.completed_items = 0
        self.description = description
        self.start_time = time.time()
        self.errors = []

    def update(self, completed: int = 1, error: str = None):
        """Update progress"""
        self.completed_items += completed
        if error:
            self.errors.append(error)

    def get_progress(self) -> Dict[str, Any]:
        """Get current progress information"""
        elapsed_time = time.time() - self.start_time
        progress_percent = (self.completed_items / self.total_items) * 100 if self.total_items > 0 else 0

        # Estimate remaining time
        if self.completed_items > 0:
            avg_time_per_item = elapsed_time / self.completed_items
            remaining_items = self.total_items - self.completed_items
            estimated_remaining_time = avg_time_per_item * remaining_items
        else:
            estimated_remaining_time = 0

        return {
            "description": self.description,
            "total_items": self.total_items,
            "completed_items": self.completed_items,
            "progress_percent": round(progress_percent, 1),
            "elapsed_time_seconds": round(elapsed_time, 1),
            "estimated_remaining_seconds": round(estimated_remaining_time, 1),
            "errors_count": len(self.errors),
            "errors": self.errors[-5:] if self.errors else []  # Last 5 errors
        }

    def is_complete(self) -> bool:
        """Check if processing is complete"""
        return self.completed_items >= self.total_items

def load_json_config(config_path: str, default_config: Dict[str, Any] = None) -> Dict[str, Any]:
    """Load configuration from JSON file with fallback to defaults"""
    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
        logger.info(f"Loaded configuration from {config_path}")
        return config
    except FileNotFoundError:
        logger.warning(f"Configuration file {config_path} not found, using defaults")
        return default_config or {}
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON in configuration file {config_path}: {str(e)}")
        return default_config or {}

def save_json_config(config: Dict[str, Any], config_path: str) -> bool:
    """Save configuration to JSON file"""
    try:
        # Create directory if it doesn't exist
        Path(config_path).parent.mkdir(parents=True, exist_ok=True)

        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)

        logger.info(f"Saved configuration to {config_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to save configuration to {config_path}: {str(e)}")
        return False

class RateLimiter:
    """Simple rate limiter for API calls"""

    def __init__(self, max_calls: int, time_window: float):
        self.max_calls = max_calls
        self.time_window = time_window
        self.calls = []

    async def acquire(self):
        """Acquire permission to make a call"""
        now = time.time()

        # Remove old calls outside the time window
        self.calls = [call_time for call_time in self.calls if now - call_time < self.time_window]

        # Check if we can make a new call
        if len(self.calls) >= self.max_calls:
            # Wait until we can make a call
            oldest_call = min(self.calls)
            wait_time = self.time_window - (now - oldest_call)
            if wait_time > 0:
                await asyncio.sleep(wait_time)
            return await self.acquire()  # Recursive call after waiting

        # Record this call
        self.calls.append(now)

def escape_markdown(text: str) -> str:
    """Escape markdown special characters"""
    import re

    # Characters that need escaping in markdown
    markdown_chars = r'([*_`\[\]()#+\-!\\])'
    return re.sub(markdown_chars, r'\\\1', text)

def create_error_summary(errors: List[Exception]) -> str:
    """Create a summary of multiple errors"""
    if not errors:
        return "No errors"

    error_counts = {}
    for error in errors:
        error_type = type(error).__name__
        error_counts[error_type] = error_counts.get(error_type, 0) + 1

    summary_parts = []
    for error_type, count in error_counts.items():
        if count == 1:
            summary_parts.append(f"1 {error_type}")
        else:
            summary_parts.append(f"{count} {error_type}s")

    return f"Encountered {len(errors)} total errors: " + ", ".join(summary_parts)

async def safe_execute(func: Callable, *args, default_return=None, **kwargs):
    """Safely execute a function and return default on error"""
    try:
        if asyncio.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        else:
            return func(*args, **kwargs)
    except Exception as e:
        logger.error(f"Error executing {func.__name__}: {str(e)}")
        return default_return

def get_content_preview(content: str, max_length: int = 200) -> str:
    """Get a preview of content for display"""
    if not content:
        return "No content"

    # Clean up whitespace
    content = ' '.join(content.split())

    if len(content) <= max_length:
        return content

    # Try to break at sentence boundary
    preview = content[:max_length]
    last_sentence_end = max(preview.rfind('.'), preview.rfind('!'), preview.rfind('?'))

    if last_sentence_end > max_length * 0.7:  # If we found a good breaking point
        return preview[:last_sentence_end + 1]
    else:
        # Break at word boundary
        last_space = preview.rfind(' ')
        if last_space > max_length * 0.7:
            return preview[:last_space] + "..."
        else:
            return preview + "..."

class MemoryUsageTracker:
    """Track memory usage of operations"""

    def __init__(self):
        self.start_memory = self._get_memory_usage()

    def _get_memory_usage(self) -> float:
        """Get current memory usage in MB"""
        try:
            import psutil
            process = psutil.Process()
            return process.memory_info().rss / 1024 / 1024  # Convert to MB
        except ImportError:
            return 0.0

    def get_usage_delta(self) -> float:
        """Get memory usage change since initialization"""
        current_memory = self._get_memory_usage()
        return current_memory - self.start_memory

    def log_usage(self, operation_name: str):
        """Log current memory usage for an operation"""
        delta = self.get_usage_delta()
        logger.info(f"{operation_name} memory delta: {delta:.1f} MB")
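A short sketch of composing the utilities above: retry_async re-invokes the wrapped coroutine with exponentially growing delays, async_timer logs wall time per call, and RateLimiter caps call frequency with a sliding window. The flaky_fetch function is a stand-in; everything else uses only the helpers defined in this file.

import asyncio
import random

from mcp_tools.utils import RateLimiter, async_timer, retry_async

limiter = RateLimiter(max_calls=5, time_window=1.0)  # at most 5 calls per second

@async_timer
@retry_async(max_attempts=3, delay=0.5, backoff=2.0)
async def flaky_fetch(i: int) -> str:
    await limiter.acquire()  # cooperative rate limiting before each attempt
    if random.random() < 0.3:  # stand-in for a transient network failure
        raise ConnectionError(f"transient failure on call {i}")
    return f"payload {i}"

async def main() -> None:
    results = await asyncio.gather(*(flaky_fetch(i) for i in range(8)))
    print(results)

asyncio.run(main())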
services/__init__.py
ADDED
@@ -0,0 +1 @@
# Services module initialization
services/document_store_service.py
ADDED
@@ -0,0 +1,349 @@
import logging
import json
import os
from typing import List, Dict, Any, Optional
from pathlib import Path
import pickle
from datetime import datetime
import asyncio

from core.models import Document, DocumentType
import config

logger = logging.getLogger(__name__)

class DocumentStoreService:
    def __init__(self):
        self.config = config.config
        self.store_path = Path(self.config.DOCUMENT_STORE_PATH)
        self.store_path.mkdir(parents=True, exist_ok=True)

        # Separate paths for metadata and content
        self.metadata_path = self.store_path / "metadata"
        self.content_path = self.store_path / "content"

        self.metadata_path.mkdir(exist_ok=True)
        self.content_path.mkdir(exist_ok=True)

        # In-memory cache for frequently accessed documents
        self._cache = {}
        self._cache_size_limit = 100

    async def store_document(self, document: Document) -> bool:
        """Store a document and its metadata"""
        try:
            # Store metadata
            metadata_file = self.metadata_path / f"{document.id}.json"
            metadata = {
                "id": document.id,
                "filename": document.filename,
                "doc_type": document.doc_type.value,
                "file_size": document.file_size,
                "created_at": document.created_at.isoformat(),
                "metadata": document.metadata,
                "tags": document.tags,
                "summary": document.summary,
                "category": document.category,
                "language": document.language,
                "content_length": len(document.content)
            }

            with open(metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2, ensure_ascii=False)

            # Store content separately (can be large)
            content_file = self.content_path / f"{document.id}.txt"
            with open(content_file, 'w', encoding='utf-8') as f:
                f.write(document.content)

            # Cache the document
            self._add_to_cache(document.id, document)

            logger.info(f"Stored document {document.id} ({document.filename})")
            return True

        except Exception as e:
            logger.error(f"Error storing document {document.id}: {str(e)}")
            return False

    async def get_document(self, document_id: str) -> Optional[Document]:
        """Retrieve a document by ID"""
        try:
            # Check cache first
            if document_id in self._cache:
                return self._cache[document_id]

            # Load from disk
            metadata_file = self.metadata_path / f"{document_id}.json"
            content_file = self.content_path / f"{document_id}.txt"

            if not metadata_file.exists() or not content_file.exists():
                return None

            # Load metadata
            with open(metadata_file, 'r', encoding='utf-8') as f:
                metadata = json.load(f)

            # Load content
            with open(content_file, 'r', encoding='utf-8') as f:
                content = f.read()

            # Create document object
            document = Document(
                id=metadata["id"],
                filename=metadata["filename"],
                content=content,
                doc_type=DocumentType(metadata["doc_type"]),
                file_size=metadata["file_size"],
                created_at=datetime.fromisoformat(metadata["created_at"]),
                metadata=metadata.get("metadata", {}),
                tags=metadata.get("tags", []),
                summary=metadata.get("summary"),
                category=metadata.get("category"),
                language=metadata.get("language")
            )

            # Add to cache
            self._add_to_cache(document_id, document)

            return document

        except Exception as e:
            logger.error(f"Error retrieving document {document_id}: {str(e)}")
            return None

    async def list_documents(self, limit: int = 50, offset: int = 0,
                             filters: Optional[Dict[str, Any]] = None) -> List[Document]:
        """List documents with pagination and filtering"""
        try:
            documents = []
            metadata_files = list(self.metadata_path.glob("*.json"))

            # Sort by creation time (newest first)
            metadata_files.sort(key=lambda x: x.stat().st_mtime, reverse=True)

            # Apply pagination
            start_idx = offset
            end_idx = offset + limit

            for metadata_file in metadata_files[start_idx:end_idx]:
                try:
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata = json.load(f)

                    # Apply filters
                    if filters and not self._apply_filters(metadata, filters):
                        continue

                    # Load content if needed (for small documents)
                    content_file = self.content_path / f"{metadata['id']}.txt"
                    if content_file.exists():
                        with open(content_file, 'r', encoding='utf-8') as f:
                            content = f.read()
                    else:
                        content = ""

                    document = Document(
                        id=metadata["id"],
                        filename=metadata["filename"],
                        content=content,
                        doc_type=DocumentType(metadata["doc_type"]),
                        file_size=metadata["file_size"],
                        created_at=datetime.fromisoformat(metadata["created_at"]),
                        metadata=metadata.get("metadata", {}),
                        tags=metadata.get("tags", []),
                        summary=metadata.get("summary"),
                        category=metadata.get("category"),
                        language=metadata.get("language")
                    )

                    documents.append(document)

                except Exception as e:
                    logger.warning(f"Error loading document metadata from {metadata_file}: {str(e)}")
                    continue

            return documents

        except Exception as e:
            logger.error(f"Error listing documents: {str(e)}")
            return []

    def _apply_filters(self, metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
        """Apply filters to document metadata"""
        try:
            for key, value in filters.items():
                if key == "doc_type":
                    if metadata.get("doc_type") != value:
                        return False
                elif key == "filename_contains":
                    if value.lower() not in metadata.get("filename", "").lower():
                        return False
                elif key == "created_after":
                    doc_date = datetime.fromisoformat(metadata.get("created_at", ""))
                    if doc_date < value:
                        return False
                elif key == "created_before":
                    doc_date = datetime.fromisoformat(metadata.get("created_at", ""))
                    if doc_date > value:
                        return False
                elif key == "tags":
                    doc_tags = set(metadata.get("tags", []))
                    required_tags = set(value) if isinstance(value, list) else {value}
                    if not required_tags.intersection(doc_tags):
                        return False
                elif key == "category":
                    if metadata.get("category") != value:
                        return False
                elif key == "language":
                    if metadata.get("language") != value:
                        return False

            return True
        except Exception as e:
|
204 |
+
logger.error(f"Error applying filters: {str(e)}")
|
205 |
+
return True
|
206 |
+
|
207 |
+
async def update_document_metadata(self, document_id: str, updates: Dict[str, Any]) -> bool:
|
208 |
+
"""Update document metadata"""
|
209 |
+
try:
|
210 |
+
metadata_file = self.metadata_path / f"{document_id}.json"
|
211 |
+
|
212 |
+
if not metadata_file.exists():
|
213 |
+
logger.warning(f"Document {document_id} not found")
|
214 |
+
return False
|
215 |
+
|
216 |
+
# Load existing metadata
|
217 |
+
with open(metadata_file, 'r', encoding='utf-8') as f:
|
218 |
+
metadata = json.load(f)
|
219 |
+
|
220 |
+
# Apply updates
|
221 |
+
for key, value in updates.items():
|
222 |
+
if key in ["tags", "summary", "category", "language", "metadata"]:
|
223 |
+
metadata[key] = value
|
224 |
+
|
225 |
+
# Save updated metadata
|
226 |
+
with open(metadata_file, 'w', encoding='utf-8') as f:
|
227 |
+
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
228 |
+
|
229 |
+
# Update cache if document is cached
|
230 |
+
if document_id in self._cache:
|
231 |
+
document = self._cache[document_id]
|
232 |
+
for key, value in updates.items():
|
233 |
+
if hasattr(document, key):
|
234 |
+
setattr(document, key, value)
|
235 |
+
|
236 |
+
logger.info(f"Updated metadata for document {document_id}")
|
237 |
+
return True
|
238 |
+
|
239 |
+
except Exception as e:
|
240 |
+
logger.error(f"Error updating document metadata: {str(e)}")
|
241 |
+
return False
|
242 |
+
|
243 |
+
async def delete_document(self, document_id: str) -> bool:
|
244 |
+
"""Delete a document and its metadata"""
|
245 |
+
try:
|
246 |
+
metadata_file = self.metadata_path / f"{document_id}.json"
|
247 |
+
content_file = self.content_path / f"{document_id}.txt"
|
248 |
+
|
249 |
+
# Remove files
|
250 |
+
if metadata_file.exists():
|
251 |
+
metadata_file.unlink()
|
252 |
+
if content_file.exists():
|
253 |
+
content_file.unlink()
|
254 |
+
|
255 |
+
# Remove from cache
|
256 |
+
if document_id in self._cache:
|
257 |
+
del self._cache[document_id]
|
258 |
+
|
259 |
+
logger.info(f"Deleted document {document_id}")
|
260 |
+
return True
|
261 |
+
|
262 |
+
except Exception as e:
|
263 |
+
logger.error(f"Error deleting document {document_id}: {str(e)}")
|
264 |
+
return False
|
265 |
+
|
266 |
+
async def search_documents(self, query: str, fields: List[str] = None) -> List[Document]:
|
267 |
+
"""Simple text search across documents"""
|
268 |
+
if not fields:
|
269 |
+
fields = ["filename", "content", "tags", "summary"]
|
270 |
+
|
271 |
+
try:
|
272 |
+
matching_documents = []
|
273 |
+
query_lower = query.lower()
|
274 |
+
|
275 |
+
# Get all documents
|
276 |
+
all_documents = await self.list_documents(limit=1000) # Adjust limit as needed
|
277 |
+
|
278 |
+
for document in all_documents:
|
279 |
+
match_found = False
|
280 |
+
|
281 |
+
for field in fields:
|
282 |
+
field_value = getattr(document, field, "")
|
283 |
+
if isinstance(field_value, list):
|
284 |
+
field_value = " ".join(field_value)
|
285 |
+
elif field_value is None:
|
286 |
+
field_value = ""
|
287 |
+
|
288 |
+
if query_lower in str(field_value).lower():
|
289 |
+
match_found = True
|
290 |
+
break
|
291 |
+
|
292 |
+
if match_found:
|
293 |
+
matching_documents.append(document)
|
294 |
+
|
295 |
+
logger.info(f"Found {len(matching_documents)} documents matching '{query}'")
|
296 |
+
return matching_documents
|
297 |
+
|
298 |
+
except Exception as e:
|
299 |
+
logger.error(f"Error searching documents: {str(e)}")
|
300 |
+
return []
|
301 |
+
|
302 |
+
def _add_to_cache(self, document_id: str, document: Document):
|
303 |
+
"""Add document to cache with size limit"""
|
304 |
+
try:
|
305 |
+
# Remove oldest items if cache is full
|
306 |
+
if len(self._cache) >= self._cache_size_limit:
|
307 |
+
# Remove first item (FIFO)
|
308 |
+
oldest_key = next(iter(self._cache))
|
309 |
+
del self._cache[oldest_key]
|
310 |
+
|
311 |
+
self._cache[document_id] = document
|
312 |
+
except Exception as e:
|
313 |
+
logger.error(f"Error adding to cache: {str(e)}")
|
314 |
+
|
315 |
+
async def get_stats(self) -> Dict[str, Any]:
|
316 |
+
"""Get statistics about the document store"""
|
317 |
+
try:
|
318 |
+
metadata_files = list(self.metadata_path.glob("*.json"))
|
319 |
+
content_files = list(self.content_path.glob("*.txt"))
|
320 |
+
|
321 |
+
# Calculate total storage size
|
322 |
+
total_size = 0
|
323 |
+
for file_path in metadata_files + content_files:
|
324 |
+
total_size += file_path.stat().st_size
|
325 |
+
|
326 |
+
# Count by document type
|
327 |
+
type_counts = {}
|
328 |
+
for metadata_file in metadata_files:
|
329 |
+
try:
|
330 |
+
with open(metadata_file, 'r') as f:
|
331 |
+
metadata = json.load(f)
|
332 |
+
doc_type = metadata.get("doc_type", "unknown")
|
333 |
+
type_counts[doc_type] = type_counts.get(doc_type, 0) + 1
|
334 |
+
except:
|
335 |
+
continue
|
336 |
+
|
337 |
+
return {
|
338 |
+
"total_documents": len(metadata_files),
|
339 |
+
"total_size_bytes": total_size,
|
340 |
+
"total_size_mb": round(total_size / (1024 * 1024), 2),
|
341 |
+
"cache_size": len(self._cache),
|
342 |
+
"document_types": type_counts,
|
343 |
+
"storage_path": str(self.store_path),
|
344 |
+
"metadata_files": len(metadata_files),
|
345 |
+
"content_files": len(content_files)
|
346 |
+
}
|
347 |
+
except Exception as e:
|
348 |
+
logger.error(f"Error getting document store stats: {str(e)}")
|
349 |
+
return {"error": str(e)}
|
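A minimal usage sketch (not part of this commit) for the store above. It assumes core.models defines Document and DocumentType with the keyword fields used in get_document; the DocumentType.TEXT member and the "doc-001" id are hypothetical.

# Hypothetical usage sketch; DocumentType.TEXT is an assumed enum member.
import asyncio
from datetime import datetime

from core.models import Document, DocumentType
from services.document_store_service import DocumentStoreService

async def demo():
    store = DocumentStoreService()
    content = "Hello, document store."
    doc = Document(
        id="doc-001",
        filename="notes.txt",
        content=content,
        doc_type=DocumentType.TEXT,   # assumed member
        file_size=len(content),
        created_at=datetime.now(),
        metadata={},
        tags=["demo"],
        summary=None,
        category=None,
        language="en",
    )
    assert await store.store_document(doc)
    loaded = await store.get_document("doc-001")
    print(loaded.filename, len(loaded.content))

asyncio.run(demo())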
services/embedding_service.py
ADDED
@@ -0,0 +1,204 @@
import logging
import asyncio
from typing import List, Optional, Dict, Any
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
import config

logger = logging.getLogger(__name__)

class EmbeddingService:
    def __init__(self):
        self.config = config.config
        self.model_name = self.config.EMBEDDING_MODEL
        self.model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Load model eagerly at construction time
        self._load_model()

    def _load_model(self):
        """Load the embedding model"""
        try:
            logger.info(f"Loading embedding model: {self.model_name}")
            self.model = SentenceTransformer(self.model_name, device=self.device)
            logger.info(f"Embedding model loaded successfully on {self.device}")
        except Exception as e:
            logger.error(f"Failed to load embedding model: {str(e)}")
            # Fallback to a smaller model
            try:
                self.model_name = "all-MiniLM-L6-v2"
                self.model = SentenceTransformer(self.model_name, device=self.device)
                logger.info(f"Loaded fallback embedding model: {self.model_name}")
            except Exception as fallback_error:
                logger.error(f"Failed to load fallback model: {str(fallback_error)}")
                raise

    async def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> List[List[float]]:
        """Generate embeddings for a list of texts"""
        if not texts:
            return []

        if self.model is None:
            raise RuntimeError("Embedding model not loaded")

        try:
            # Filter out empty texts
            non_empty_texts = [text for text in texts if text and text.strip()]
            if not non_empty_texts:
                logger.warning("No non-empty texts provided for embedding")
                return []

            logger.info(f"Generating embeddings for {len(non_empty_texts)} texts")

            # Process in batches to manage memory
            all_embeddings = []
            for i in range(0, len(non_empty_texts), batch_size):
                batch = non_empty_texts[i:i + batch_size]

                # Run embedding generation in thread pool to avoid blocking
                loop = asyncio.get_event_loop()
                batch_embeddings = await loop.run_in_executor(
                    None,
                    self._generate_batch_embeddings,
                    batch
                )
                all_embeddings.extend(batch_embeddings)

            logger.info(f"Generated {len(all_embeddings)} embeddings")
            return all_embeddings

        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}")
            raise

    def _generate_batch_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a batch of texts (synchronous)"""
        try:
            # Generate embeddings
            embeddings = self.model.encode(
                texts,
                convert_to_numpy=True,
                normalize_embeddings=True,
                batch_size=len(texts)
            )

            # Convert to list of lists
            return embeddings.tolist()
        except Exception as e:
            logger.error(f"Error in batch embedding generation: {str(e)}")
            raise

    async def generate_single_embedding(self, text: str) -> Optional[List[float]]:
        """Generate embedding for a single text"""
        if not text or not text.strip():
            return None

        try:
            embeddings = await self.generate_embeddings([text])
            return embeddings[0] if embeddings else None
        except Exception as e:
            logger.error(f"Error generating single embedding: {str(e)}")
            return None

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings produced by the model"""
        if self.model is None:
            raise RuntimeError("Embedding model not loaded")

        return self.model.get_sentence_embedding_dimension()

    def compute_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
        """Compute cosine similarity between two embeddings"""
        try:
            # Convert to numpy arrays
            emb1 = np.array(embedding1)
            emb2 = np.array(embedding2)

            # Compute cosine similarity
            similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))

            return float(similarity)
        except Exception as e:
            logger.error(f"Error computing similarity: {str(e)}")
            return 0.0

    def compute_similarities(self, query_embedding: List[float], embeddings: List[List[float]]) -> List[float]:
        """Compute similarities between a query embedding and multiple embeddings"""
        try:
            query_emb = np.array(query_embedding)
            emb_matrix = np.array(embeddings)

            # Compute cosine similarities
            similarities = np.dot(emb_matrix, query_emb) / (
                np.linalg.norm(emb_matrix, axis=1) * np.linalg.norm(query_emb)
            )

            return similarities.tolist()
        except Exception as e:
            logger.error(f"Error computing similarities: {str(e)}")
            return [0.0] * len(embeddings)

    async def embed_chunks(self, chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Embed a list of chunks and add embeddings to them"""
        if not chunks:
            return []

        try:
            # Extract texts
            texts = [chunk.get('content', '') for chunk in chunks]

            # Generate embeddings
            embeddings = await self.generate_embeddings(texts)

            # Add embeddings to chunks
            embedded_chunks = []
            for i, chunk in enumerate(chunks):
                if i < len(embeddings):
                    chunk_copy = chunk.copy()
                    chunk_copy['embedding'] = embeddings[i]
                    embedded_chunks.append(chunk_copy)
                else:
                    logger.warning(f"No embedding generated for chunk {i}")
                    embedded_chunks.append(chunk)

            return embedded_chunks
        except Exception as e:
            logger.error(f"Error embedding chunks: {str(e)}")
            raise

    def validate_embedding(self, embedding: List[float]) -> bool:
        """Validate that an embedding is properly formatted"""
        try:
            if not embedding:
                return False

            if not isinstance(embedding, list):
                return False

            if len(embedding) != self.get_embedding_dimension():
                return False

            # Check for NaN or infinite values
            emb_array = np.array(embedding)
            if np.isnan(emb_array).any() or np.isinf(emb_array).any():
                return False

            return True
        except Exception:
            return False

    async def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model"""
        try:
            return {
                "model_name": self.model_name,
                "device": self.device,
                "embedding_dimension": self.get_embedding_dimension(),
                "max_sequence_length": getattr(self.model, 'max_seq_length', 'unknown'),
                "model_loaded": self.model is not None
            }
        except Exception as e:
            logger.error(f"Error getting model info: {str(e)}")
            return {"error": str(e)}
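A minimal usage sketch (not part of this commit) exercising the service above. It downloads the configured sentence-transformers model on first run; all calls shown exist in the file.

# Hypothetical usage sketch for EmbeddingService.
import asyncio
from services.embedding_service import EmbeddingService

async def demo():
    svc = EmbeddingService()  # loads the model eagerly, may download weights
    embs = await svc.generate_embeddings(["vector search", "semantic retrieval"])
    print("dimension:", svc.get_embedding_dimension())
    # Embeddings are normalized, so cosine similarity equals the dot product
    print("similarity:", svc.compute_similarity(embs[0], embs[1]))

asyncio.run(demo())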
services/llm_service.py
ADDED
@@ -0,0 +1,285 @@
import logging
import asyncio
from typing import List, Dict, Any, Optional
import anthropic
from mistralai.client import MistralClient
import config

logger = logging.getLogger(__name__)

class LLMService:
    def __init__(self):
        self.config = config.config

        # Initialize clients
        self.anthropic_client = None
        self.mistral_client = None

        self._initialize_clients()

    def _initialize_clients(self):
        """Initialize LLM clients"""
        try:
            if self.config.ANTHROPIC_API_KEY:
                self.anthropic_client = anthropic.Anthropic(
                    api_key=self.config.ANTHROPIC_API_KEY
                )
                logger.info("Anthropic client initialized")

            if self.config.MISTRAL_API_KEY:
                self.mistral_client = MistralClient(
                    api_key=self.config.MISTRAL_API_KEY
                )
                logger.info("Mistral client initialized")

            if not self.anthropic_client and not self.mistral_client:
                raise ValueError("No LLM clients could be initialized. Check API keys.")

        except Exception as e:
            logger.error(f"Error initializing LLM clients: {str(e)}")
            raise

    async def generate_text(self, prompt: str, model: str = "auto", max_tokens: int = 1000, temperature: float = 0.7) -> str:
        """Generate text using the specified model"""
        try:
            if model == "auto":
                # Use Claude if available, otherwise Mistral
                if self.anthropic_client:
                    return await self._generate_with_claude(prompt, max_tokens, temperature)
                elif self.mistral_client:
                    return await self._generate_with_mistral(prompt, max_tokens, temperature)
                else:
                    raise ValueError("No LLM clients available")
            elif model.startswith("claude"):
                if not self.anthropic_client:
                    raise ValueError("Anthropic client not available")
                return await self._generate_with_claude(prompt, max_tokens, temperature)
            elif model.startswith("mistral"):
                if not self.mistral_client:
                    raise ValueError("Mistral client not available")
                return await self._generate_with_mistral(prompt, max_tokens, temperature)
            else:
                raise ValueError(f"Unsupported model: {model}")
        except Exception as e:
            logger.error(f"Error generating text: {str(e)}")
            raise

    async def _generate_with_claude(self, prompt: str, max_tokens: int, temperature: float) -> str:
        """Generate text using Claude"""
        try:
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                None,
                lambda: self.anthropic_client.messages.create(
                    model=self.config.ANTHROPIC_MODEL,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    messages=[
                        {"role": "user", "content": prompt}
                    ]
                )
            )

            return response.content[0].text
        except Exception as e:
            logger.error(f"Error with Claude generation: {str(e)}")
            raise

    async def _generate_with_mistral(self, prompt: str, max_tokens: int, temperature: float) -> str:
        """Generate text using Mistral"""
        try:
            loop = asyncio.get_event_loop()
            response = await loop.run_in_executor(
                None,
                lambda: self.mistral_client.chat(
                    model=self.config.MISTRAL_MODEL,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=temperature
                )
            )

            return response.choices[0].message.content
        except Exception as e:
            logger.error(f"Error with Mistral generation: {str(e)}")
            raise

    async def summarize(self, text: str, style: str = "concise", max_length: Optional[int] = None) -> str:
        """Generate a summary of the given text"""
        if not text.strip():
            return ""

        # Create style-specific prompts
        style_prompts = {
            "concise": "Provide a concise summary of the following text, focusing on the main points:",
            "detailed": "Provide a detailed summary of the following text, including key details and supporting information:",
            "bullet_points": "Summarize the following text as a list of bullet points highlighting the main ideas:",
            "executive": "Provide an executive summary of the following text, focusing on key findings and actionable insights:"
        }

        prompt_template = style_prompts.get(style, style_prompts["concise"])

        if max_length:
            prompt_template += f" Keep the summary under {max_length} words."

        prompt = f"{prompt_template}\n\nText to summarize:\n{text}\n\nSummary:"

        try:
            summary = await self.generate_text(prompt, max_tokens=500, temperature=0.3)
            return summary.strip()
        except Exception as e:
            logger.error(f"Error generating summary: {str(e)}")
            return "Error generating summary"

    async def generate_tags(self, text: str, max_tags: int = 5) -> List[str]:
        """Generate relevant tags for the given text"""
        if not text.strip():
            return []

        prompt = f"""Generate {max_tags} relevant tags for the following text.
Tags should be concise, descriptive keywords or phrases that capture the main topics, themes, or concepts.
Return only the tags, separated by commas.

Text:
{text}

Tags:"""

        try:
            response = await self.generate_text(prompt, max_tokens=100, temperature=0.5)

            # Parse tags from response
            tags = [tag.strip() for tag in response.split(',')]
            tags = [tag for tag in tags if tag and len(tag) > 1]

            return tags[:max_tags]
        except Exception as e:
            logger.error(f"Error generating tags: {str(e)}")
            return []

    async def categorize(self, text: str, categories: List[str]) -> str:
        """Categorize text into one of the provided categories"""
        if not text.strip() or not categories:
            return "Uncategorized"

        categories_str = ", ".join(categories)

        prompt = f"""Classify the following text into one of these categories: {categories_str}

Choose the most appropriate category based on the content and main theme of the text.
Return only the category name, nothing else.

Text to classify:
{text}

Category:"""

        try:
            response = await self.generate_text(prompt, max_tokens=50, temperature=0.1)
            category = response.strip()

            # Validate that the response is one of the provided categories
            if category in categories:
                return category
            else:
                # Try to find a close match
                category_lower = category.lower()
                for cat in categories:
                    if cat.lower() in category_lower or category_lower in cat.lower():
                        return cat

                return categories[0] if categories else "Uncategorized"
        except Exception as e:
            logger.error(f"Error categorizing text: {str(e)}")
            return "Uncategorized"

    async def answer_question(self, question: str, context: str, max_context_length: int = 2000) -> str:
        """Answer a question based on the provided context"""
        if not question.strip():
            return "No question provided"

        if not context.strip():
            return "I don't have enough context to answer this question. Please provide more relevant information."

        # Truncate context if too long
        if len(context) > max_context_length:
            context = context[:max_context_length] + "..."

        prompt = f"""Based on the following context, answer the question. If the context doesn't contain enough information to answer the question completely, say so and provide what information you can.

Context:
{context}

Question: {question}

Answer:"""

        try:
            answer = await self.generate_text(prompt, max_tokens=300, temperature=0.3)
            return answer.strip()
        except Exception as e:
            logger.error(f"Error answering question: {str(e)}")
            return "I encountered an error while trying to answer your question."

    async def extract_key_information(self, text: str) -> Dict[str, Any]:
        """Extract key information from text"""
        if not text.strip():
            return {}

        prompt = f"""Analyze the following text and extract key information. Provide the response in the following format:

Main Topic: [main topic or subject]
Key Points: [list 3-5 key points]
Entities: [important people, places, organizations mentioned]
Sentiment: [positive/neutral/negative]
Content Type: [article/document/email/report/etc.]

Text to analyze:
{text}

Analysis:"""

        try:
            response = await self.generate_text(prompt, max_tokens=400, temperature=0.4)

            # Parse the structured response
            info = {}
            lines = response.split('\n')

            for line in lines:
                if ':' in line:
                    key, value = line.split(':', 1)
                    key = key.strip().lower().replace(' ', '_')
                    value = value.strip()
                    if value:
                        info[key] = value

            return info
        except Exception as e:
            logger.error(f"Error extracting key information: {str(e)}")
            return {}

    async def check_availability(self) -> Dict[str, bool]:
        """Check which LLM services are available"""
        availability = {
            "anthropic": False,
            "mistral": False
        }

        try:
            if self.anthropic_client:
                # Test Claude availability with a simple request
                test_response = await self._generate_with_claude("Hello", 10, 0.1)
                availability["anthropic"] = bool(test_response)
        except Exception:
            pass

        try:
            if self.mistral_client:
                # Test Mistral availability with a simple request
                test_response = await self._generate_with_mistral("Hello", 10, 0.1)
                availability["mistral"] = bool(test_response)
        except Exception:
            pass

        return availability
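A minimal usage sketch (not part of this commit) for the service above. It assumes ANTHROPIC_API_KEY or MISTRAL_API_KEY is set in config; otherwise the constructor raises, as the code shows.

# Hypothetical usage sketch for LLMService; requires at least one API key in config.
import asyncio
from services.llm_service import LLMService

async def demo():
    llm = LLMService()
    print(await llm.check_availability())
    text = "FAISS is a library for efficient similarity search over dense vectors."
    # "auto" prefers Claude when configured, otherwise falls back to Mistral
    print(await llm.summarize(text, style="bullet_points"))
    print(await llm.generate_tags(text, max_tags=3))

asyncio.run(demo())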
services/ocr_service.py
ADDED
@@ -0,0 +1,324 @@
import logging
from typing import Optional, List, Dict, Any
import asyncio
from pathlib import Path
import tempfile
import os

from PIL import Image
import pytesseract
import config

logger = logging.getLogger(__name__)

class OCRService:
    def __init__(self):
        self.config = config.config

        # Configure Tesseract path if specified
        if self.config.TESSERACT_PATH:
            pytesseract.pytesseract.tesseract_cmd = self.config.TESSERACT_PATH

        self.language = self.config.OCR_LANGUAGE

        # Test OCR availability
        self._test_ocr_availability()

    def _test_ocr_availability(self):
        """Test if OCR is available and working"""
        try:
            # Create a simple test image
            test_image = Image.new('RGB', (100, 30), color='white')
            pytesseract.image_to_string(test_image)
            logger.info("OCR service initialized successfully")
        except Exception as e:
            logger.warning(f"OCR may not be available: {str(e)}")

    async def extract_text_from_image(self, image_path: str, language: Optional[str] = None) -> str:
        """Extract text from an image file"""
        try:
            # Use specified language or default
            lang = language or self.language

            # Load image
            image = Image.open(image_path)

            # Perform OCR in thread pool to avoid blocking
            loop = asyncio.get_event_loop()
            text = await loop.run_in_executor(
                None,
                self._extract_text_sync,
                image,
                lang
            )

            return text.strip()

        except Exception as e:
            logger.error(f"Error extracting text from image {image_path}: {str(e)}")
            return ""

    def _extract_text_sync(self, image: Image.Image, language: str) -> str:
        """Synchronous text extraction"""
        try:
            # Optimize image for OCR
            processed_image = self._preprocess_image(image)

            # Configure OCR
            config_string = '--psm 6'  # Assume a single uniform block of text

            # Extract text
            text = pytesseract.image_to_string(
                processed_image,
                lang=language,
                config=config_string
            )

            return text
        except Exception as e:
            logger.error(f"Error in synchronous OCR: {str(e)}")
            return ""

    def _preprocess_image(self, image: Image.Image) -> Image.Image:
        """Preprocess image to improve OCR accuracy"""
        try:
            # Convert to grayscale if not already
            if image.mode != 'L':
                image = image.convert('L')

            # Resize image if too small (OCR works better on larger images)
            width, height = image.size
            if width < 300 or height < 300:
                scale_factor = max(300 / width, 300 / height)
                new_width = int(width * scale_factor)
                new_height = int(height * scale_factor)
                image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

            return image
        except Exception as e:
            logger.error(f"Error preprocessing image: {str(e)}")
            return image

    async def extract_text_from_pdf_images(self, pdf_path: str) -> List[str]:
        """Extract text from PDF by converting pages to images and running OCR"""
        try:
            import fitz  # PyMuPDF

            texts = []

            # Open PDF
            pdf_document = fitz.open(pdf_path)

            for page_num in range(len(pdf_document)):
                try:
                    # Get page
                    page = pdf_document[page_num]

                    # Convert page to image
                    mat = fitz.Matrix(2.0, 2.0)  # Scale factor for better quality
                    pix = page.get_pixmap(matrix=mat)
                    img_data = pix.tobytes("ppm")

                    # Create PIL image from bytes
                    with tempfile.NamedTemporaryFile(suffix='.ppm', delete=False) as tmp_file:
                        tmp_file.write(img_data)
                        tmp_file.flush()

                        # Extract text from image
                        page_text = await self.extract_text_from_image(tmp_file.name)
                        texts.append(page_text)

                        # Clean up temporary file
                        os.unlink(tmp_file.name)

                except Exception as e:
                    logger.warning(f"Error processing PDF page {page_num}: {str(e)}")
                    texts.append("")

            pdf_document.close()
            return texts

        except ImportError:
            logger.error("PyMuPDF not available for PDF OCR")
            return []
        except Exception as e:
            logger.error(f"Error extracting text from PDF images: {str(e)}")
            return []

    async def extract_text_with_confidence(self, image_path: str, min_confidence: float = 0.5) -> Dict[str, Any]:
        """Extract text with confidence scores"""
        try:
            image = Image.open(image_path)

            # Get detailed OCR data with confidence scores
            loop = asyncio.get_event_loop()
            ocr_data = await loop.run_in_executor(
                None,
                self._extract_detailed_data,
                image
            )

            # Filter by confidence
            filtered_text = []
            word_confidences = []

            for i, confidence in enumerate(ocr_data.get('conf', [])):
                if confidence > min_confidence * 100:  # Tesseract uses a 0-100 scale
                    text = ocr_data.get('text', [])[i]
                    if text.strip():
                        filtered_text.append(text)
                        word_confidences.append(confidence / 100.0)  # Convert to 0-1 scale

            return {
                "text": " ".join(filtered_text),
                "confidence": sum(word_confidences) / len(word_confidences) if word_confidences else 0.0,
                "word_count": len(filtered_text),
                "raw_data": ocr_data
            }

        except Exception as e:
            logger.error(f"Error extracting text with confidence: {str(e)}")
            return {
                "text": "",
                "confidence": 0.0,
                "word_count": 0,
                "error": str(e)
            }

    def _extract_detailed_data(self, image: Image.Image) -> Dict[str, Any]:
        """Extract detailed OCR data with positions and confidence"""
        try:
            processed_image = self._preprocess_image(image)

            # Get detailed data
            data = pytesseract.image_to_data(
                processed_image,
                lang=self.language,
                config='--psm 6',
                output_type=pytesseract.Output.DICT
            )

            return data
        except Exception as e:
            logger.error(f"Error extracting detailed OCR data: {str(e)}")
            return {}

    async def detect_language(self, image_path: str) -> str:
        """Detect the language of text in an image"""
        try:
            image = Image.open(image_path)

            # Run script detection (OSD)
            loop = asyncio.get_event_loop()
            languages = await loop.run_in_executor(
                None,
                pytesseract.image_to_osd,
                image
            )

            # Parse the OSD output to get the detected script
            for line in languages.split('\n'):
                if 'Script:' in line:
                    script = line.split(':')[1].strip()
                    # Map script to language code
                    script_to_lang = {
                        'Latin': 'eng',
                        'Arabic': 'ara',
                        'Chinese': 'chi_sim',
                        'Japanese': 'jpn',
                        'Korean': 'kor'
                    }
                    return script_to_lang.get(script, 'eng')

            return 'eng'  # Default to English

        except Exception as e:
            logger.error(f"Error detecting language: {str(e)}")
            return 'eng'

    async def extract_tables_from_image(self, image_path: str) -> List[List[str]]:
        """Extract table data from an image"""
        try:
            # This is a basic implementation.
            # For better table extraction, consider using specialized libraries like table-transformer.

            image = Image.open(image_path)

            # Use a PSM setting that preserves inter-word spacing for tables
            loop = asyncio.get_event_loop()
            text = await loop.run_in_executor(
                None,
                lambda: pytesseract.image_to_string(
                    image,
                    lang=self.language,
                    config='--psm 6 -c preserve_interword_spaces=1'
                )
            )

            # Simple table parsing (assumes whitespace-separated cells)
            lines = text.split('\n')
            table_data = []

            for line in lines:
                if line.strip():
                    # Split on whitespace
                    cells = [cell.strip() for cell in line.split() if cell.strip()]
                    if cells:
                        table_data.append(cells)

            return table_data

        except Exception as e:
            logger.error(f"Error extracting tables from image: {str(e)}")
            return []

    async def get_supported_languages(self) -> List[str]:
        """Get list of supported OCR languages"""
        try:
            languages = pytesseract.get_languages()
            return sorted(languages)
        except Exception as e:
            logger.error(f"Error getting supported languages: {str(e)}")
            return ['eng']  # Default to English only

    async def validate_ocr_setup(self) -> Dict[str, Any]:
        """Validate OCR setup and return status"""
        try:
            # Test basic functionality
            test_image = Image.new('RGB', (200, 50), color='white')

            from PIL import ImageDraw, ImageFont
            draw = ImageDraw.Draw(test_image)

            try:
                # Try to render sample text with the default font
                draw.text((10, 10), "Test OCR", fill='black')
            except Exception:
                # Fall back to shorter text
                draw.text((10, 10), "Test", fill='black')

            # Test OCR
            result = pytesseract.image_to_string(test_image)

            # Get available languages
            languages = await self.get_supported_languages()

            return {
                "status": "operational",
                "tesseract_version": pytesseract.get_tesseract_version(),
                "available_languages": languages,
                "current_language": self.language,
                "test_result": result.strip(),
                "tesseract_path": pytesseract.pytesseract.tesseract_cmd
            }

        except Exception as e:
            return {
                "status": "error",
                "error": str(e),
                "tesseract_path": pytesseract.pytesseract.tesseract_cmd
            }

    def extract_text(self, file_path):
        # Dummy implementation for OCR
        return "OCR functionality not implemented yet."
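A minimal usage sketch (not part of this commit) for the service above. It requires a local Tesseract install, and "scan.png" is a hypothetical input file.

# Hypothetical usage sketch for OCRService; "scan.png" is a placeholder path.
import asyncio
from services.ocr_service import OCRService

async def demo():
    ocr = OCRService()
    print(await ocr.validate_ocr_setup())
    text = await ocr.extract_text_from_image("scan.png")
    print(text[:200])

asyncio.run(demo())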
services/vector_store_service.py
ADDED
@@ -0,0 +1,285 @@
import logging
import os
import pickle
import numpy as np
from typing import List, Dict, Any, Optional, Tuple
import faiss
from pathlib import Path
import asyncio
import json

from core.models import SearchResult, Chunk
import config

logger = logging.getLogger(__name__)

class VectorStoreService:
    def __init__(self):
        self.config = config.config
        self.index = None
        self.chunks_metadata = {}  # Maps index position to chunk metadata
        self.dimension = None

        # Paths
        self.store_path = Path(self.config.VECTOR_STORE_PATH)
        self.store_path.mkdir(parents=True, exist_ok=True)

        self.index_path = self.store_path / f"{self.config.INDEX_NAME}.index"
        self.metadata_path = self.store_path / f"{self.config.INDEX_NAME}_metadata.json"

        # Load existing index if available
        self._load_index()

    def _load_index(self):
        """Load existing FAISS index and metadata"""
        try:
            if self.index_path.exists() and self.metadata_path.exists():
                logger.info("Loading existing FAISS index...")

                # Load FAISS index
                self.index = faiss.read_index(str(self.index_path))
                self.dimension = self.index.d

                # Load metadata
                with open(self.metadata_path, 'r') as f:
                    self.chunks_metadata = json.load(f)

                logger.info(f"Loaded index with {self.index.ntotal} vectors, dimension {self.dimension}")
            else:
                logger.info("No existing index found, will create new one")
        except Exception as e:
            logger.error(f"Error loading index: {str(e)}")

    def _initialize_index(self, dimension: int):
        """Initialize a new FAISS index"""
        try:
            # Use IndexFlatIP for cosine similarity (since embeddings are normalized)
            self.index = faiss.IndexFlatIP(dimension)
            self.dimension = dimension
            self.chunks_metadata = {}
            logger.info(f"Initialized new FAISS index with dimension {dimension}")
        except Exception as e:
            logger.error(f"Error initializing index: {str(e)}")
            raise

    async def add_chunks(self, chunks: List[Chunk]) -> bool:
        """Add chunks to the vector store"""
        if not chunks:
            return True

        try:
            # Extract embeddings and metadata
            embeddings = []
            new_metadata = {}

            for chunk in chunks:
                if chunk.embedding and len(chunk.embedding) > 0:
                    embeddings.append(chunk.embedding)
                    # Store metadata using the current index position
                    current_index = len(self.chunks_metadata) + len(embeddings) - 1
                    new_metadata[str(current_index)] = {
                        "chunk_id": chunk.id,
                        "document_id": chunk.document_id,
                        "content": chunk.content,
                        "chunk_index": chunk.chunk_index,
                        "start_pos": chunk.start_pos,
                        "end_pos": chunk.end_pos,
                        "metadata": chunk.metadata
                    }

            if not embeddings:
                logger.warning("No valid embeddings found in chunks")
                return False

            # Initialize index if needed
            if self.index is None:
                self._initialize_index(len(embeddings[0]))

            # Convert to numpy array
            embeddings_array = np.array(embeddings, dtype=np.float32)

            # Add to FAISS index
            self.index.add(embeddings_array)

            # Update metadata
            self.chunks_metadata.update(new_metadata)

            # Save index and metadata
            await self._save_index()

            logger.info(f"Added {len(embeddings)} chunks to vector store")
            return True

        except Exception as e:
            logger.error(f"Error adding chunks to vector store: {str(e)}")
            return False

    async def search(self, query_embedding: List[float], top_k: int = 5,
                     filters: Optional[Dict[str, Any]] = None) -> List[SearchResult]:
        """Search for similar chunks"""
        if self.index is None or self.index.ntotal == 0:
            logger.warning("No index available or index is empty")
            return []

        try:
            # Convert query embedding to numpy array
            query_array = np.array([query_embedding], dtype=np.float32)

            # Perform search
            scores, indices = self.index.search(query_array, min(top_k, self.index.ntotal))

            # Convert results to SearchResult objects
            results = []
            for score, idx in zip(scores[0], indices[0]):
                if idx == -1:  # FAISS returns -1 for empty slots
                    continue

                chunk_metadata = self.chunks_metadata.get(str(idx))
                if chunk_metadata:
                    # Apply filters if specified
                    if filters and not self._apply_filters(chunk_metadata, filters):
                        continue

                    result = SearchResult(
                        chunk_id=chunk_metadata["chunk_id"],
                        document_id=chunk_metadata["document_id"],
                        content=chunk_metadata["content"],
                        score=float(score),
                        metadata=chunk_metadata.get("metadata", {})
                    )
                    results.append(result)

            # Sort by score (descending)
            results.sort(key=lambda x: x.score, reverse=True)

            logger.info(f"Found {len(results)} search results")
            return results

        except Exception as e:
            logger.error(f"Error searching vector store: {str(e)}")
            return []

    def _apply_filters(self, chunk_metadata: Dict[str, Any], filters: Dict[str, Any]) -> bool:
        """Apply filters to chunk metadata"""
        try:
            for key, value in filters.items():
                if key == "document_id":
                    if chunk_metadata.get("document_id") != value:
                        return False
                elif key == "document_ids":
                    if chunk_metadata.get("document_id") not in value:
                        return False
                elif key == "content_length_min":
                    if len(chunk_metadata.get("content", "")) < value:
                        return False
                elif key == "content_length_max":
                    if len(chunk_metadata.get("content", "")) > value:
                        return False
                # Add more filter types as needed

            return True
        except Exception as e:
            logger.error(f"Error applying filters: {str(e)}")
            return True

    async def _save_index(self):
        """Save the FAISS index and metadata to disk"""
        try:
            if self.index is not None:
                # Save FAISS index
                faiss.write_index(self.index, str(self.index_path))

                # Save metadata
                with open(self.metadata_path, 'w') as f:
                    json.dump(self.chunks_metadata, f, indent=2)

                logger.debug("Saved index and metadata to disk")
        except Exception as e:
            logger.error(f"Error saving index: {str(e)}")

    async def get_stats(self) -> Dict[str, Any]:
        """Get statistics about the vector store"""
        try:
            return {
                "total_vectors": self.index.ntotal if self.index else 0,
                "dimension": self.dimension,
                "index_type": type(self.index).__name__ if self.index else None,
                "metadata_entries": len(self.chunks_metadata),
                "index_file_exists": self.index_path.exists(),
                "metadata_file_exists": self.metadata_path.exists()
            }
        except Exception as e:
            logger.error(f"Error getting stats: {str(e)}")
            return {"error": str(e)}

    async def delete_document(self, document_id: str) -> bool:
        """Delete all chunks for a specific document"""
        try:
            # Find indices to remove
            indices_to_remove = []
            for idx, metadata in self.chunks_metadata.items():
                if metadata.get("document_id") == document_id:
                    indices_to_remove.append(int(idx))

            if not indices_to_remove:
                logger.warning(f"No chunks found for document {document_id}")
                return False

            # FAISS doesn't support removing individual vectors efficiently
            # We need to rebuild the index without the removed vectors
            if self.index and self.index.ntotal > 0:
                # Get all embeddings except the ones to remove
                all_embeddings = []
                new_metadata = {}
                new_index = 0

                for old_idx in range(self.index.ntotal):
                    if old_idx not in indices_to_remove:
                        # Get the embedding from FAISS
                        embedding = self.index.reconstruct(old_idx)
                        all_embeddings.append(embedding)

                        # Update metadata with new index
                        old_metadata = self.chunks_metadata.get(str(old_idx))
                        if old_metadata:
                            new_metadata[str(new_index)] = old_metadata
                        new_index += 1

                # Rebuild index
                if all_embeddings:
                    self._initialize_index(self.dimension)
                    embeddings_array = np.array(all_embeddings, dtype=np.float32)
                    self.index.add(embeddings_array)
                    self.chunks_metadata = new_metadata
                else:
                    # No embeddings left, create empty index
                    self._initialize_index(self.dimension)

            # Save updated index
            await self._save_index()

            logger.info(f"Deleted {len(indices_to_remove)} chunks for document {document_id}")
            return True

        except Exception as e:
            logger.error(f"Error deleting document chunks: {str(e)}")
            return False

    async def clear_all(self) -> bool:
        """Clear all data from the vector store"""
        try:
            self.index = None
            self.chunks_metadata = {}
            self.dimension = None

            # Remove files
            if self.index_path.exists():
                self.index_path.unlink()
            if self.metadata_path.exists():
                self.metadata_path.unlink()

            logger.info("Cleared all data from vector store")
            return True
        except Exception as e:
            logger.error(f"Error clearing vector store: {str(e)}")
            return False
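A minimal end-to-end sketch (not part of this commit) tying EmbeddingService to the vector store above. It assumes core.models.Chunk accepts the keyword fields referenced in add_chunks; the ids and text are hypothetical.

# Hypothetical end-to-end sketch: embed one chunk, index it, then search.
import asyncio
from core.models import Chunk
from services.embedding_service import EmbeddingService
from services.vector_store_service import VectorStoreService

async def demo():
    embedder = EmbeddingService()
    store = VectorStoreService()

    text = "FAISS indexes normalized embeddings for cosine search."
    chunk = Chunk(
        id="chunk-0",
        document_id="doc-001",
        content=text,
        chunk_index=0,
        start_pos=0,
        end_pos=len(text),
        metadata={},
        embedding=await embedder.generate_single_embedding(text),
    )
    await store.add_chunks([chunk])

    # With IndexFlatIP over normalized vectors, scores are cosine similarities
    query = await embedder.generate_single_embedding("cosine similarity search")
    for hit in await store.search(query, top_k=3):
        print(round(hit.score, 3), hit.content[:60])

asyncio.run(demo())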