Create agent.py

agent.py (ADDED)
import asyncio
import os
import sys
import threading
import time
import re
import atexit
from contextlib import asynccontextmanager
from typing import Any, Optional, List, Dict, Tuple, Callable
from smolagents import CodeAgent, MCPClient
from smolagents.models import Model
from inference import initialize, generate_content
from workflow_vizualizer import track_workflow_step, track_communication, complete_workflow_step

# Global session management
_session_initialized = False
_session_lock = threading.Lock()
_session_start_time = None

# Enhanced global caching for Phase 2 optimizations with async support
_global_tools_cache = {}
_global_tools_timestamp = None
_global_model_instance = None
_global_model_lock = threading.Lock()
_global_connection_pool = {}
_global_connection_lock = threading.Lock()

# Managed event loop system
_managed_event_loop = None
_event_loop_lock = threading.Lock()
_event_loop_manager = None  # Global event loop manager instance


@asynccontextmanager
async def managed_event_loop():
    """Proper async context manager for event loop lifecycle."""
    global _managed_event_loop

    try:
        # Create a new event loop if needed
        if _managed_event_loop is None or _managed_event_loop.is_closed():
            _managed_event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(_managed_event_loop)

        print("✅ Event loop initialized and set as current")
        yield _managed_event_loop

    except Exception as e:
        print(f"❌ Event loop error: {e}")
        raise
    finally:
        # Don't close the loop here - let it be managed at a higher level
        pass


async def safe_async_call(coroutine, timeout=30):
    """Safely execute async calls with proper error handling.

    Caveat: the retry branch below re-awaits the same coroutine object; this
    only works if the first await never started it, since a coroutine cannot
    be awaited twice.
    """
    try:
        return await asyncio.wait_for(coroutine, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"⏱️ Async call timed out after {timeout}s")
        raise
    except RuntimeError as e:
        if "Event loop is closed" in str(e):
            print("🔄 Event loop closed - attempting to create new one")
            # Create a new event loop and retry
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return await asyncio.wait_for(coroutine, timeout=timeout)
        raise

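# Illustrative usage sketch (hypothetical; nothing in this module calls
# safe_async_call directly). `fetch_remote_tools` stands in for any coroutine:
#
#   async def demo():
#       tools = await safe_async_call(fetch_remote_tools(), timeout=10)
#       return tools
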
class AsyncEventLoopManager:
    def __init__(self):
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        print("AsyncEventLoopManager: Initialized and thread started.")

    def _run_loop(self):
        if self._loop is None:
            print("AsyncEventLoopManager: _run_loop called but loop is None.")
            return
        asyncio.set_event_loop(self._loop)
        try:
            print("AsyncEventLoopManager: Event loop running.")
            self._loop.run_forever()
        except Exception as e:
            print(f"AsyncEventLoopManager: Exception in event loop: {e}")
        finally:
            # Ensure the loop is stopped if it was running.
            # The actual closing is handled by the shutdown() method.
            if self._loop and self._loop.is_running():
                self._loop.stop()
            print("AsyncEventLoopManager: Event loop stopped in _run_loop finally.")

    def run_async(self, coro):
        """Run a coroutine in the event loop from another thread."""
        coro_name = getattr(coro, '__name__', str(coro))
        if self._loop is None:
            print(f"AsyncEventLoopManager: Loop object is None. Cannot run coroutine {coro_name}.")
            raise RuntimeError("Event loop manager is not properly initialized (loop missing).")

        if self._loop.is_closed():
            print(f"AsyncEventLoopManager: Loop is CLOSED. Cannot schedule coroutine {coro_name}.")
            raise RuntimeError(f"Event loop is closed. Cannot run {coro_name}.")

        if self._thread is None or not self._thread.is_alive():
            print(f"AsyncEventLoopManager: Event loop thread is not alive or None. Cannot run coroutine {coro_name}.")
            raise RuntimeError("Event loop thread is not alive or None.")

        try:
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return future.result(timeout=30)  # Assuming a 30s timeout
        except RuntimeError as e:
            print(f"AsyncEventLoopManager: RuntimeError during run_coroutine_threadsafe for {coro_name}: {e}")
            raise
        except asyncio.TimeoutError:
            print(f"AsyncEventLoopManager: Timeout waiting for coroutine {coro_name} result.")
            raise
        except Exception as e:
            print(f"AsyncEventLoopManager: Error submitting coroutine {coro_name}: {e}")
            raise

    def shutdown(self):
        """Stop and close the event loop."""
        print("AsyncEventLoopManager: Shutdown initiated.")
        if self._loop and not self._loop.is_closed():
            if self._loop.is_running():
                self._loop.call_soon_threadsafe(self._loop.stop)
                print("AsyncEventLoopManager: Stop signal sent to running event loop.")
            else:
                print("AsyncEventLoopManager: Event loop was not running, but attempting to stop.")
                # If not running, stop might not be necessary or might error,
                # but call_soon_threadsafe should be safe.
                try:
                    self._loop.call_soon_threadsafe(self._loop.stop)
                except RuntimeError as e:
                    print(f"AsyncEventLoopManager: Info - could not send stop to non-running loop: {e}")

        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=10)
            if self._thread.is_alive():
                print("AsyncEventLoopManager: Thread did not join in time during shutdown.")
            else:
                print("AsyncEventLoopManager: Thread joined.")
        else:
            print("AsyncEventLoopManager: Thread already stopped, not initialized, or None at shutdown.")

        # Explicitly close the loop here after the thread has finished.
        if self._loop and not self._loop.is_closed():
            try:
                # Ensure all tasks are cancelled before closing.
                if sys.version_info >= (3, 7):  # asyncio.all_tasks(loop) is available in 3.7+
                    tasks = asyncio.all_tasks(self._loop)
                    for task in tasks:
                        task.cancel()
                # Waiting for tasks to cancel should ideally be done within the loop's
                # thread, but since we are shutting down from outside, this is a best effort.
                # self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
                self._loop.close()
                print("AsyncEventLoopManager: Event loop closed in shutdown.")
            except Exception as e:
                print(f"AsyncEventLoopManager: Exception while closing loop: {e}")
        elif self._loop and self._loop.is_closed():
            print("AsyncEventLoopManager: Event loop was already closed.")
        else:
            print("AsyncEventLoopManager: No loop to close or loop was None.")

        self._loop = None
        self._thread = None
        print("AsyncEventLoopManager: Shutdown process complete.")

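# Illustrative usage sketch: the manager lets synchronous code drive coroutines
# on a dedicated background loop (see get_event_loop_manager below):
#
#   manager = get_event_loop_manager()
#   result = manager.run_async(wrapper.ensure_connected())  # blocks up to 30s
#   shutdown_event_loop_manager()
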
def get_event_loop_manager():
    """Get or create the global event loop manager."""
    global _event_loop_manager

    with _event_loop_lock:
        # Check if a manager exists and its loop and thread are valid
        manager_valid = False
        if _event_loop_manager is not None:
            # More robust check: loop exists, is not closed, thread exists and is alive
            if _event_loop_manager._loop is not None and \
               not _event_loop_manager._loop.is_closed() and \
               _event_loop_manager._thread is not None and \
               _event_loop_manager._thread.is_alive():
                manager_valid = True
            else:
                print("get_event_loop_manager: Existing manager found but its loop or thread is invalid. Recreating.")
                try:
                    _event_loop_manager.shutdown()  # Attempt to clean up the old one
                except Exception as e:
                    print(f"get_event_loop_manager: Error shutting down invalid manager: {e}")
                _event_loop_manager = None  # Ensure it's None so a new one is created

        if _event_loop_manager is None:  # Covers both initial creation and recreation
            print("get_event_loop_manager: Creating new AsyncEventLoopManager instance.")
            _event_loop_manager = AsyncEventLoopManager()
        else:
            print("get_event_loop_manager: Reusing existing valid AsyncEventLoopManager instance.")
        return _event_loop_manager


def shutdown_event_loop_manager():
    """Shutdown the global event loop manager."""
    global _event_loop_manager
    with _event_loop_lock:
        if _event_loop_manager is not None:
            print("shutdown_event_loop_manager: Shutting down global event loop manager.")
            try:
                _event_loop_manager.shutdown()
            except Exception as e:
                print(f"shutdown_event_loop_manager: Error during shutdown: {e}")
            finally:
                _event_loop_manager = None
        else:
            print("shutdown_event_loop_manager: No active event loop manager to shut down.")


class AsyncMCPClientWrapper:
    """Wrapper for async MCP client operations."""

    def __init__(self, url: str):
        self.url = url
        self._mcp_client = None
        self._tools = None
        self._tools_cache_time = None
        self._cache_ttl = 300  # 5 minutes cache
        self._connected = False

    async def ensure_connected(self):
        """Ensure async connection is established."""
        if not self._connected or self._mcp_client is None:
            try:
                # Create MCP client with SSE transport for Gradio
                self._mcp_client = MCPClient({"url": self.url, "transport": "sse"})
                # Attempt a lightweight operation to confirm connectivity, e.g. get_tools or a custom ping.
                # For now, we assume MCPClient constructor success implies basic connectivity.
                # If get_tools is lightweight enough, it can be called here:
                # await self._mcp_client.get_tools()  # Example, if get_tools is async and suitable
                self._connected = True
                print(f"✅ Connected to MCP server: {self.url}")
            except Exception as e:
                self._connected = False
                print(f"❌ Failed to connect to {self.url}: {e}")
                raise

    async def get_tools(self):
        """Get tools asynchronously."""
        current_time = time.time()

        # Check instance cache
        if (self._tools is not None and
                self._tools_cache_time is not None and
                current_time - self._tools_cache_time < self._cache_ttl):
            return self._tools

        # Fetch fresh tools
        await self.ensure_connected()  # Ensures client is connected

        if self._mcp_client is None:  # Should be caught by ensure_connected, but as a safeguard
            raise RuntimeError("MCP client not connected")

        try:
            # Assuming MCPClient.get_tools() is a synchronous method based on original structure.
            # If it were async, it would be `await self._mcp_client.get_tools()`.
            self._tools = self._mcp_client.get_tools()
            self._tools_cache_time = current_time
            tool_names = [tool.name for tool in self._tools] if self._tools else []
            print(f"🔧 Fetched {len(tool_names)} tools from {self.url}: {tool_names}")

            return self._tools
        except Exception as e:
            print(f"❌ Error fetching tools from {self.url}: {e}")
            # Potentially reset connection status if the error indicates a connection problem:
            # self._connected = False
            raise

    async def disconnect(self):
        """Gracefully disconnect."""
        if self._mcp_client and self._connected:
            try:
                # Assuming MCPClient.disconnect() is synchronous.
                # If it were async, it would be `await self._mcp_client.disconnect()`.
                self._mcp_client.disconnect()
            except Exception as e:
                # Log the error but continue to mark as disconnected.
                print(f"Error during MCPClient disconnect for {self.url}: {e}")
        self._connected = False
        self._mcp_client = None
        print(f"🔌 Disconnected from MCP server: {self.url}")

class AsyncPersistentMCPClient:
    """Async-aware persistent MCP client that survives multiple requests."""

    def __init__(self, url: str):
        self.url = url
        self._wrapper = AsyncMCPClientWrapper(url)
        self._loop_manager = None

    def ensure_connected(self):
        """Sync wrapper for async connection."""
        if self._loop_manager is None:
            self._loop_manager = get_event_loop_manager()

        conn_step = track_communication("agent", "mcp_client", "connection_ensure", f"Ensuring connection to {self.url}")
        try:
            # Ensure we have a valid loop manager
            if self._loop_manager is None:
                self._loop_manager = get_event_loop_manager()

            # Additional safety check
            if self._loop_manager is None:
                raise RuntimeError("Failed to create event loop manager")

            # Pass the coroutine object itself, not its result
            self._loop_manager.run_async(self._wrapper.ensure_connected())
            complete_workflow_step(conn_step, "completed", details={"url": self.url})
        except Exception as e:
            complete_workflow_step(conn_step, "error", details={"error": str(e)})
            raise

    def get_client(self):
        """Get the underlying MCP client."""
        self.ensure_connected()
        return self._wrapper._mcp_client

    def get_tools(self):
        """Get tools with enhanced caching and async support."""
        global _global_tools_cache, _global_tools_timestamp
        current_time = time.time()

        if self._loop_manager is None:
            self._loop_manager = get_event_loop_manager()

        # Phase 2 Optimization: Check server-specific global cache first
        with _global_connection_lock:
            server_cache_key = self.url
            server_cache = _global_tools_cache.get(server_cache_key, {})

            if (server_cache and _global_tools_timestamp and
                    current_time - _global_tools_timestamp < 300):
                # Track global cache hit
                cache_step = track_communication("mcp_client", "mcp_server", "cache_hit_global", f"Using global cached tools for {self.url}")
                complete_workflow_step(cache_step, "completed", details={
                    "tools": list(server_cache.keys()),
                    "cache_type": "global_server_specific",
                    "server_url": self.url,
                    "cache_age": current_time - _global_tools_timestamp
                })
                return list(server_cache.values())

        # Fetch fresh tools using async
        tools_step = track_communication("mcp_client", "mcp_server", "get_tools", f"Fetching tools from {self.url} (cache refresh)")
        try:
            # Ensure we have a valid loop manager
            if self._loop_manager is None:
                self._loop_manager = get_event_loop_manager()

            # Additional safety check
            if self._loop_manager is None:
                raise RuntimeError("Failed to create event loop manager")

            # Pass the coroutine object itself
            tools = self._loop_manager.run_async(self._wrapper.get_tools())

            # Update global cache
            with _global_connection_lock:
                if tools:
                    if server_cache_key not in _global_tools_cache:
                        _global_tools_cache[server_cache_key] = {}

                    _global_tools_cache[server_cache_key] = {tool.name: tool for tool in tools}
                    _global_tools_timestamp = current_time

                    total_tools = sum(len(server_tools) for server_tools in _global_tools_cache.values())
                    print(f"🔄 Global tools cache updated for {self.url}: {len(tools)} tools")
                    print(f"   Total cached tools across all servers: {total_tools}")

            tool_names = [tool.name for tool in tools] if tools else []
            complete_workflow_step(tools_step, "completed", details={
                "tools": tool_names,
                "count": len(tool_names),
                "server_url": self.url,
                "cache_status": "refreshed_server_specific",
                "global_cache_servers": len(_global_tools_cache)
            })
            return tools

        except Exception as e:
            complete_workflow_step(tools_step, "error", details={"error": str(e), "server_url": self.url})
            raise

    def disconnect(self):
        """Gracefully disconnect."""
        if self._loop_manager and self._wrapper:
            try:
                # Ensure we have a valid loop manager
                if self._loop_manager is None:
                    self._loop_manager = get_event_loop_manager()

                # Additional safety check
                if self._loop_manager is None:
                    raise RuntimeError("Failed to create event loop manager")

                # Pass the coroutine object itself
                self._loop_manager.run_async(self._wrapper.disconnect())
            except RuntimeError as e:
                # Handle cases where the loop might already be closed or unable to run tasks
                print(f"AsyncPersistentMCPClient: Error running disconnect for {self.url} in async loop: {e}")
            except Exception as e:
                print(f"AsyncPersistentMCPClient: General error during disconnect for {self.url}: {e}")

def get_mcp_client(url: str = "https://NLarchive-Agent-client-multi-mcp-SKT.hf.space/gradio_api/mcp/sse") -> AsyncPersistentMCPClient:
    """Get or create an MCP client with enhanced global connection pooling."""
    # Phase 2 Optimization: Use global connection pool
    with _global_connection_lock:
        if url not in _global_connection_pool:
            conn_step = track_communication("agent", "mcp_client", "connection_create", f"Creating new global connection to {url}")
            _global_connection_pool[url] = AsyncPersistentMCPClient(url)
            complete_workflow_step(conn_step, "completed", details={"url": url, "pool_size": len(_global_connection_pool)})
        else:
            # Track connection reuse
            reuse_step = track_communication("agent", "mcp_client", "connection_reuse", f"Reusing global connection to {url}")
            complete_workflow_step(reuse_step, "completed", details={"url": url, "pool_size": len(_global_connection_pool)})

        return _global_connection_pool[url]

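# Illustrative usage sketch: pooled clients are keyed by URL, so repeated calls
# with the same URL return the same AsyncPersistentMCPClient instance:
#
#   client = get_mcp_client("https://nlarchive-mcp-sentiment.hf.space/gradio_api/mcp/sse")
#   tools = client.get_tools()  # served from the global cache for ~5 minutes
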
def get_global_model() -> 'CachedLocalInferenceModel':
    """Get or create global model instance for Phase 2 optimization."""
    global _global_model_instance

    with _global_model_lock:
        if _global_model_instance is None:
            model_step = track_workflow_step("model_init_global", "Initializing global model instance")

            # CRITICAL FIX: Create and assign BEFORE initialization
            _global_model_instance = CachedLocalInferenceModel()

            # Now initialize the model
            try:
                _global_model_instance.ensure_initialized()
                complete_workflow_step(model_step, "completed", details={"model_type": "global_cached"})
                print("🤖 Global model instance created and initialized")
            except Exception as e:
                # If initialization fails, reset the global instance
                _global_model_instance = None
                complete_workflow_step(model_step, "error", details={"error": str(e)})
                raise
        else:
            # Track model reuse
            reuse_step = track_workflow_step("model_reuse", "Reusing global model instance")
            complete_workflow_step(reuse_step, "completed", details={"model_type": "global_cached"})

        return _global_model_instance


def reset_global_state():
    """Reset global state for testing purposes with server-specific cache awareness."""
    global _global_tools_cache, _global_tools_timestamp, _global_model_instance, _global_connection_pool, _event_loop_manager

    with _global_connection_lock:
        # FIXED: Clear server-specific cache structure (don't rebind!)
        _global_tools_cache.clear()  # Clears the {url: {tool_name: tool}} structure
        _global_tools_timestamp = None

        # Disconnect all connections but keep the pool structure
        for client in _global_connection_pool.values():
            try:
                client.disconnect()
            except Exception:
                pass

    with _global_model_lock:
        # Don't reset the model instance - it should persist
        pass

    print("🔄 Global state reset for testing (server-specific cache cleared)")

# Enhanced LocalInferenceModel with workflow tracking
class CachedLocalInferenceModel(Model):
    """Model with enhanced caching and session persistence."""

    def __init__(self):
        super().__init__()
        self._response_cache = {}
        self._cache_hits = 0
        self._cache_misses = 0
        self._model_ready = False

    def ensure_initialized(self):
        """Lazy initialization of the model."""
        if not self._model_ready:
            init_step = track_workflow_step("model_init", "Initializing inference model (lazy)")
            try:
                initialize()
                self._model_ready = True
                complete_workflow_step(init_step, "completed")
            except Exception as e:
                complete_workflow_step(init_step, "error", details={"error": str(e)})
                raise

    def generate(self, messages: Any, **kwargs: Any) -> Any:
        self.ensure_initialized()

        prompt = self._format_messages(messages)

        # Enhanced cache with hash-based lookup
        cache_key = hash(prompt)
        if cache_key in self._response_cache:
            self._cache_hits += 1
            cached_response = self._response_cache[cache_key]

            # Track cache hit
            cache_step = track_communication("agent", "llm_service", "cache_hit", "Using cached response")
            complete_workflow_step(cache_step, "completed", details={
                "cache_hits": self._cache_hits,
                "cache_misses": self._cache_misses,
                "cache_ratio": self._cache_hits / (self._cache_hits + self._cache_misses)
            })

            return ModelResponse(cached_response.content, prompt)

        self._cache_misses += 1

        # Track LLM call
        llm_step = track_communication("agent", "llm_service", "generate_request", "Generating new response")

        try:
            enhanced_prompt = self._enhance_prompt_for_tools(prompt)

            response_text = generate_content(
                prompt=enhanced_prompt,
                model_name=kwargs.get('model_name'),
                allow_fallbacks=True,
                generation_config={
                    'temperature': kwargs.get('temperature', 0.3),
                    'max_output_tokens': kwargs.get('max_tokens', 512)
                }
            )

            # Validate and fix response format
            if not self._is_valid_code_response(response_text):
                response_text = self._fix_response_format(response_text, prompt)

            response = ModelResponse(str(response_text), prompt)

            # Smart cache management (keep the 10 most recent responses)
            if len(self._response_cache) >= 10:
                # Remove the oldest entry (simple FIFO)
                oldest_key = next(iter(self._response_cache))
                del self._response_cache[oldest_key]

            self._response_cache[cache_key] = response

            complete_workflow_step(llm_step, "completed", details={
                "cache_status": "new",
                "input_tokens": response.token_usage.input_tokens,
                "output_tokens": response.token_usage.output_tokens,
                "model": response.model
            })

            return response

        except Exception as e:
            fallback_response = self._create_fallback_response(prompt, str(e))
            complete_workflow_step(llm_step, "error", details={"error": str(e)})
            return ModelResponse(fallback_response, prompt)

    def _enhance_prompt_for_tools(self, prompt: str) -> str:
        """Enhance the prompt with better tool usage examples."""
        if "sentiment" in prompt.lower():
            tool_example = """
IMPORTANT: When calling sentiment_analysis, use keyword arguments only:
Correct: sentiment_analysis(text="your text here")
Wrong: sentiment_analysis("your text here")

Example:
```py
text = "this is horrible"
result = sentiment_analysis(text=text)
final_answer(result)
```"""
            return prompt + "\n" + tool_example
        return prompt

    def _format_messages(self, messages: Any) -> str:
        """Convert messages to a single prompt string."""
        if isinstance(messages, str):
            return messages
        elif isinstance(messages, list):
            prompt_parts = []
            for msg in messages:
                if isinstance(msg, dict):
                    if 'content' in msg:
                        content = msg['content']
                        role = msg.get('role', 'user')
                        if isinstance(content, list):
                            text_parts = [part.get('text', '') for part in content if part.get('type') == 'text']
                            content = ' '.join(text_parts)
                        prompt_parts.append(f"{role}: {content}")
                    elif 'text' in msg:
                        prompt_parts.append(msg['text'])
                elif hasattr(msg, 'content'):
                    prompt_parts.append(str(msg.content))
                else:
                    prompt_parts.append(str(msg))
            return '\n'.join(prompt_parts)
        else:
            return str(messages)

    def _is_valid_code_response(self, response: str) -> bool:
        """Check if the response contains a valid code block format."""
        code_pattern = r'```(?:py|python)?\s*\n(.*?)\n```'
        return bool(re.search(code_pattern, response, re.DOTALL))

    def _fix_response_format(self, response: str, original_prompt: str) -> str:
        """Try to fix the response format to match the expected pattern."""
        # Attempt to remove or comment out "Thoughts:" if it is not in a code block already.
        # This is a common source of SyntaxError if the LLM includes it directly.
        if "Thoughts:" in response and "```" not in response.split("Thoughts:")[0]:
            # If "Thoughts:" appears before any code block, comment it out
            response = response.replace("Thoughts:", "# Thoughts:", 1)

        if "sentiment" in original_prompt.lower():
            text_to_analyze = "neutral text"
            if "this is horrible" in original_prompt:
                text_to_analyze = "this is horrible"
            elif "awful" in original_prompt:
                text_to_analyze = "awful"

            return f"""Thoughts: I need to analyze the sentiment of the given text using the sentiment_analysis tool.
Code:
```py
text = "{text_to_analyze}"
result = sentiment_analysis(text=text)
final_answer(result)
```<end_code>"""

        if "```" in response and ("Thoughts:" in response or "Code:" in response):
            return response

        clean_response = response.replace('"', '\\"').replace('\n', '\\n')
        return f"""Thoughts: Processing the user's request.
Code:
```py
result = "{clean_response}"
final_answer(result)
```<end_code>"""

    def _create_fallback_response(self, prompt: str, error_msg: str) -> str:
        """Create a valid fallback response when the model fails."""
        return f"""Thoughts: The AI service is experiencing issues, providing a fallback response.
Code:
```py
error_message = "I apologize, but the AI service is temporarily experiencing high load. Please try again in a moment."
final_answer(error_message)
```<end_code>"""

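# Illustrative note: _is_valid_code_response accepts fenced blocks of the shape
# that _fix_response_format reconstructs, e.g. (hypothetical model output):
#
#   Thoughts: ...
#   Code:
#   ```py
#   result = sentiment_analysis(text="...")
#   final_answer(result)
#   ```<end_code>
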
class TokenUsage:
    def __init__(self, input_tokens: int = 0, output_tokens: int = 0):
        self.input_tokens = input_tokens
        self.output_tokens = output_tokens
        self.total_tokens = input_tokens + output_tokens
        self.prompt_tokens = input_tokens
        self.completion_tokens = output_tokens

class ModelResponse:
    def __init__(self, content: str, prompt: str = ""):
        self.content = content
        self.text = content
        estimated_input_tokens = len(prompt.split()) if prompt else 0
        estimated_output_tokens = len(content.split()) if content else 0
        self.token_usage = TokenUsage(estimated_input_tokens, estimated_output_tokens)
        self.finish_reason = 'stop'
        self.model = 'local-inference'

    def __str__(self):
        return self.content

# Global variables
_mcp_client = None
_tools = None
_model = None
_agent = None
_initialized = False
_initialization_lock = threading.Lock()

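# Illustrative note: token counts above are rough whitespace-based estimates,
# not real tokenizer counts. For example (hypothetical values):
#
#   resp = ModelResponse("four words of output", prompt="two words")
#   resp.token_usage.input_tokens   # -> 2
#   resp.token_usage.output_tokens  # -> 4
#   resp.token_usage.total_tokens   # -> 6
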
def initialize_agent():
    """Initialize the agent components with Hugging Face Spaces MCP servers."""
    global _mcp_client, _tools, _model, _agent, _initialized

    with _initialization_lock:
        if _initialized:
            skip_step = track_workflow_step("agent_init_skip", "Agent already initialized - using cached instance")
            complete_workflow_step(skip_step, "completed", details={"optimization": "session_persistence"})
            return

        try:
            print("Initializing MCP agent...")

            agent_init_step = track_workflow_step("agent_init", "Initializing MCP agent components")

            # Get clients for Hugging Face Spaces servers
            all_tools = []
            tool_names = set()

            # Semantic Search & Keywords server
            try:
                semantic_client = get_mcp_client("https://nlarchive-mcp-semantic-keywords.hf.space/gradio_api/mcp/sse")
                semantic_tools = semantic_client.get_tools()
                for tool in semantic_tools:
                    if tool.name not in tool_names:
                        all_tools.append(tool)
                        tool_names.add(tool.name)
                print(f"Connected to semantic server: {len(semantic_tools)} tools - {[t.name for t in semantic_tools]}")
            except Exception as e:
                print(f"WARNING: Semantic server unavailable: {e}")

            # Token Counter server
            try:
                token_client = get_mcp_client("https://nlarchive-mcp-gr-token-counter.hf.space/gradio_api/mcp/sse")
                token_tools = token_client.get_tools()
                for tool in token_tools:
                    if tool.name not in tool_names:
                        all_tools.append(tool)
                        tool_names.add(tool.name)
                print(f"Connected to token counter server: {len(token_tools)} tools - {[t.name for t in token_tools]}")
            except Exception as e:
                print(f"WARNING: Token counter server unavailable: {e}")

            # Sentiment Analysis server
            try:
                sentiment_client = get_mcp_client("https://nlarchive-mcp-sentiment.hf.space/gradio_api/mcp/sse")
                sentiment_tools = sentiment_client.get_tools()
                for tool in sentiment_tools:
                    if tool.name not in tool_names:
                        all_tools.append(tool)
                        tool_names.add(tool.name)
                print(f"Connected to sentiment analysis server: {len(sentiment_tools)} tools - {[t.name for t in sentiment_tools]}")
            except Exception as e:
                print(f"WARNING: Sentiment analysis server unavailable: {e}")

            _tools = all_tools
            _model = get_global_model()

            # Create agent with unique tools only
            _agent = CodeAgent(tools=_tools, model=_model)

            complete_workflow_step(agent_init_step, "completed", details={
                "tools_count": len(_tools),
                "unique_tool_names": list(tool_names),
                "servers_connected": 3
            })

            _initialized = True
            print(f"Agent initialized with {len(_tools)} unique tools: {list(tool_names)}")

        except Exception as e:
            print(f"Agent initialization failed: {e}")
            _model = get_global_model()
            _agent = CodeAgent(tools=[], model=_model)
            _initialized = True
            print("Agent initialized in fallback mode")

def is_agent_initialized() -> bool:
    """Check if the agent is initialized."""
    return _initialized

def run_agent(message: str) -> str:
    """Send a message through the agent with comprehensive tracking."""
    if not _initialized:
        initialize_agent()
    if _agent is None:
        raise RuntimeError("Agent not properly initialized")

    # Track agent processing
    process_step = track_workflow_step("agent_process", f"Processing: {message}")

    try:
        # Enhanced tool tracking
        tool_step: Optional[str] = None
        detected_tools = []

        # Detect potential tool usage
        if any(keyword in message.lower() for keyword in ['sentiment', 'analyze', 'feeling']):
            detected_tools.append('sentiment_analysis')
        if any(keyword in message.lower() for keyword in ['token', 'count']):
            detected_tools.extend(['count_tokens_openai_gpt4', 'count_tokens_bert_family'])
        if any(keyword in message.lower() for keyword in ['semantic', 'similar', 'keyword']):
            detected_tools.extend(['semantic_similarity', 'extract_semantic_keywords'])

        if detected_tools:
            tool_step = track_communication("agent", "mcp_server", "tool_call",
                                            f"Executing tools {detected_tools} for: {message[:50]}...")

        result = _agent.run(message)

        # Complete the tool step if it was tracked
        if tool_step is not None:
            complete_workflow_step(tool_step, "completed", details={
                "result": str(result)[:100],
                "detected_tools": detected_tools
            })

        complete_workflow_step(process_step, "completed", details={
            "result_length": len(str(result)),
            "detected_tools": detected_tools
        })

        return str(result)

    except Exception as e:
        error_msg = str(e)
        print(f"Agent execution error: {error_msg}")

        complete_workflow_step(process_step, "error", details={"error": error_msg})

        # Enhanced error responses
        if "503" in error_msg or "overloaded" in error_msg.lower():
            return "I apologize, but the AI service is currently experiencing high demand. Please try again in a few moments."
        elif "rate limit" in error_msg.lower():
            return "The service is currently rate-limited. Please wait a moment before trying again."
        elif "event loop" in error_msg.lower():
            return "There was an async processing issue. The system is recovering. Please try again."
        else:
            return "I encountered an error while processing your request. Please try rephrasing your question or try again later."

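# Illustrative usage sketch: run_agent is the synchronous entry point and
# lazily initializes the agent on first call:
#
#   reply = run_agent("What is the sentiment of 'this is horrible'?")
#   print(reply)
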
def disconnect():
    """Cleanly disconnect connections with global pool management."""
    global _mcp_client, _initialized
    disconnect_step = track_workflow_step("agent_disconnect", "Disconnecting MCP client")

    try:
        # Phase 2 Optimization: Preserve global connections for reuse
        with _global_connection_lock:
            preserved_connections = 0
            for url, client in _global_connection_pool.items():
                try:
                    # Keep connections alive but mark them as idle
                    if hasattr(client, '_last_used'):
                        client._last_used = time.time()
                    preserved_connections += 1
                except Exception:
                    pass

        complete_workflow_step(disconnect_step, "completed", details={
            "preserved_connections": preserved_connections,
            "optimization": "connection_persistence"
        })
    except Exception as e:
        complete_workflow_step(disconnect_step, "error", details={"error": str(e)})
    finally:
        # Don't reset global state - preserve it for the next session
        _initialized = False

def initialize_session():
    """Initialize the persistent session - alias for initialize_agent."""
    initialize_agent()

def is_session_initialized() -> bool:
    """Check if the persistent session is initialized - alias for is_agent_initialized."""
    return is_agent_initialized()

# Make sure these are exported for imports
__all__ = [
    'run_agent', 'initialize_agent', 'is_agent_initialized', 'disconnect',
    'initialize_session', 'is_session_initialized',
    'get_mcp_client', 'get_global_model', 'reset_global_state',
    '_global_tools_cache', '_global_connection_pool', '_global_model_instance',
    '_global_connection_lock', '_global_model_lock'
]

# Register cleanup function
def cleanup_global_resources():
    """Cleanup function for graceful shutdown."""
    global _global_connection_pool, _event_loop_manager

    print("Cleaning up global resources...")

    with _global_connection_lock:
        for client in _global_connection_pool.values():
            try:
                client.disconnect()
            except Exception:
                pass
        _global_connection_pool.clear()

    # Shutdown the event loop manager
    with _event_loop_lock:
        if _event_loop_manager:
            try:
                _event_loop_manager.shutdown()
            except Exception:
                pass
            _event_loop_manager = None

# Register cleanup on exit
atexit.register(cleanup_global_resources)
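
if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming the MCP Spaces above are reachable
    # and the local `inference` module is configured. Not part of the module API.
    initialize_agent()
    print("initialized:", is_agent_initialized())
    print(run_agent("What is the sentiment of 'this is horrible'?"))
    disconnect()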