NLarchive committed on
Commit
42dd27a
·
verified ·
1 Parent(s): 93cb588

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +900 -0
agent.py ADDED
@@ -0,0 +1,900 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import sys
4
+ import threading
5
+ import time
6
+ import re
7
+ import atexit
8
+ from contextlib import asynccontextmanager
9
+ from typing import Any, Optional, List, Dict, Tuple, Callable
10
+ from smolagents import CodeAgent, MCPClient
11
+ from smolagents.models import Model
12
+ from inference import initialize, generate_content
13
+ from workflow_vizualizer import track_workflow_step, track_communication, complete_workflow_step
14
+
15
# Global session management
_session_initialized = False      # NOTE(review): not written anywhere in this chunk -- confirm usage elsewhere
_session_lock = threading.Lock()
_session_start_time = None

# Enhanced global caching for Phase 2 optimizations with async support
_global_tools_cache = {}          # {server_url: {tool_name: tool}}; cleared in place by reset_global_state()
_global_tools_timestamp = None    # time.time() of the last cache refresh (shared across all servers)
_global_model_instance = None     # singleton model, created lazily by get_global_model()
_global_model_lock = threading.Lock()
_global_connection_pool = {}      # {server_url: AsyncPersistentMCPClient}, filled by get_mcp_client()
_global_connection_lock = threading.Lock()

# Managed event loop system
_managed_event_loop = None        # loop owned by the managed_event_loop() context manager
_event_loop_lock = threading.Lock()
_event_loop_manager = None  # Global event loop manager instance
32
+
33
+
34
@asynccontextmanager
async def managed_event_loop():
    """Async context manager yielding the module-wide managed event loop.

    A fresh loop is created (and installed as the current one) only when no
    usable loop exists yet.  The loop is intentionally NOT closed on exit;
    its lifecycle is owned at a higher level.
    """
    global _managed_event_loop

    try:
        needs_new_loop = _managed_event_loop is None or _managed_event_loop.is_closed()
        if needs_new_loop:
            fresh_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(fresh_loop)
            _managed_event_loop = fresh_loop

        print("✅ Event loop initialized and set as current")
        yield _managed_event_loop

    except Exception as exc:
        print(f"❌ Event loop error: {exc}")
        raise
    finally:
        # Deliberately no cleanup: see docstring.
        pass
54
+
55
+
56
async def safe_async_call(coroutine, timeout=30):
    """Await *coroutine* with a timeout and uniform error reporting.

    Parameters
    ----------
    coroutine : coroutine object (not yet awaited)
    timeout : seconds to wait before giving up

    Returns the coroutine's result.

    Raises
    ------
    asyncio.TimeoutError
        If the call does not finish within *timeout* seconds.
    RuntimeError
        Propagated unchanged (e.g. "Event loop is closed").
    """
    try:
        return await asyncio.wait_for(coroutine, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"⏱️ Async call timed out after {timeout}s")
        raise
    except RuntimeError as e:
        # BUGFIX: the previous code created a brand-new event loop and then
        # re-awaited the SAME coroutine object.  A coroutine can only be
        # awaited once, so that retry always failed with "cannot reuse
        # already awaited coroutine" (and we are already running inside a
        # loop here, so the new loop could never be used anyway).  Report
        # the condition and propagate instead of pretending to retry.
        if "Event loop is closed" in str(e):
            print("🔄 Event loop closed - cannot retry an already-awaited coroutine")
        raise
71
+
72
+
73
class AsyncEventLoopManager:
    """Owns a private asyncio event loop running on a daemon thread.

    Synchronous callers submit coroutines via run_async(); shutdown() stops
    the loop, joins the worker thread, cancels remaining tasks and closes
    the loop.
    """

    def __init__(self):
        # Loop and thread are created eagerly; the daemon thread blocks in
        # run_forever() until shutdown() sends a stop signal.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        print("AsyncEventLoopManager: Initialized and thread started.")

    def _run_loop(self):
        # Thread target: installs self._loop as this thread's current loop
        # and runs it forever until stop() is called from shutdown().
        if self._loop is None:
            print("AsyncEventLoopManager: _run_loop called but loop is None.")
            return
        asyncio.set_event_loop(self._loop)
        try:
            print("AsyncEventLoopManager: Event loop running.")
            self._loop.run_forever()
        except Exception as e:
            print(f"AsyncEventLoopManager: Exception in event loop: {e}")
        finally:
            # Ensure the loop is stopped if it was running.
            # The actual closing is handled by the shutdown() method.
            if self._loop and self._loop.is_running():
                self._loop.stop()
            print("AsyncEventLoopManager: Event loop stopped in _run_loop finally.")

    def run_async(self, coro):
        """Run a coroutine in the event loop from another thread.

        Blocks the calling thread until the coroutine completes (hard 30s
        timeout) and returns its result.

        Raises RuntimeError when the loop is missing/closed or the worker
        thread is dead, asyncio.TimeoutError when the result does not
        arrive in time.
        """
        coro_name = getattr(coro, '__name__', str(coro))
        if self._loop is None:
            print(f"AsyncEventLoopManager: Loop object is None. Cannot run coroutine {coro_name}.")
            raise RuntimeError("Event loop manager is not properly initialized (loop missing).")

        if self._loop.is_closed():
            print(f"AsyncEventLoopManager: Loop is CLOSED. Cannot schedule coroutine {coro_name}.")
            raise RuntimeError(f"Event loop is closed. Cannot run {coro_name}.")

        if self._thread is None or not self._thread.is_alive():
            print(f"AsyncEventLoopManager: Event loop thread is not alive or None. Cannot run coroutine {coro_name}.")
            raise RuntimeError("Event loop thread is not alive or None.")

        try:
            # Thread-safe hand-off of the coroutine to the loop's thread.
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return future.result(timeout=30)  # Assuming a 30s timeout
        except RuntimeError as e:
            print(f"AsyncEventLoopManager: RuntimeError during run_coroutine_threadsafe for {coro_name}: {e}")
            raise
        except asyncio.TimeoutError:
            print(f"AsyncEventLoopManager: Timeout waiting for coroutine {coro_name} result.")
            raise
        except Exception as e:
            print(f"AsyncEventLoopManager: Error submitting coroutine {coro_name}: {e}")
            raise

    def shutdown(self):
        """Stop and close the event loop."""
        print("AsyncEventLoopManager: Shutdown initiated.")
        if self._loop and not self._loop.is_closed():
            if self._loop.is_running():
                self._loop.call_soon_threadsafe(self._loop.stop)
                print("AsyncEventLoopManager: Stop signal sent to running event loop.")
            else:
                print("AsyncEventLoopManager: Event loop was not running, but attempting to stop.")
                # If not running, stop might not be necessary or might error,
                # but call_soon_threadsafe should be safe.
                try:
                    self._loop.call_soon_threadsafe(self._loop.stop)
                except RuntimeError as e:
                    print(f"AsyncEventLoopManager: Info - could not send stop to non-running loop: {e}")

        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=10)
            if self._thread.is_alive():
                print("AsyncEventLoopManager: Thread did not join in time during shutdown.")
            else:
                print("AsyncEventLoopManager: Thread joined.")
        else:
            print("AsyncEventLoopManager: Thread already stopped, not initialized, or None at shutdown.")

        # Explicitly close the loop here after the thread has finished.
        if self._loop and not self._loop.is_closed():
            try:
                # Ensure all tasks are cancelled before closing
                # Gather all tasks:
                if sys.version_info >= (3, 7):  # gather works on all tasks in 3.7+
                    tasks = asyncio.all_tasks(self._loop)
                    for task in tasks:
                        task.cancel()
                # Wait for tasks to cancel - this should be done within the loop's thread ideally
                # but since we are shutting down from outside, this is a best effort.
                # self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
                self._loop.close()
                print("AsyncEventLoopManager: Event loop closed in shutdown.")
            except Exception as e:
                print(f"AsyncEventLoopManager: Exception while closing loop: {e}")
        elif self._loop and self._loop.is_closed():
            print("AsyncEventLoopManager: Event loop was already closed.")
        else:
            print("AsyncEventLoopManager: No loop to close or loop was None.")

        self._loop = None
        self._thread = None
        print("AsyncEventLoopManager: Shutdown process complete.")
176
+
177
def get_event_loop_manager():
    """Return the process-wide AsyncEventLoopManager.

    A cached manager is reused only while its loop is open and its worker
    thread alive; otherwise it is shut down (best effort) and replaced.
    """
    global _event_loop_manager

    with _event_loop_lock:
        mgr = _event_loop_manager
        if mgr is not None:
            loop_ok = mgr._loop is not None and not mgr._loop.is_closed()
            thread_ok = mgr._thread is not None and mgr._thread.is_alive()
            if not (loop_ok and thread_ok):
                print("get_event_loop_manager: Existing manager found but its loop or thread is invalid. Recreating.")
                try:
                    mgr.shutdown()  # Attempt to clean up the old one
                except Exception as e:
                    print(f"get_event_loop_manager: Error shutting down invalid manager: {e}")
                _event_loop_manager = None  # force recreation below

        if _event_loop_manager is None:
            print("get_event_loop_manager: Creating new AsyncEventLoopManager instance.")
            _event_loop_manager = AsyncEventLoopManager()
        else:
            print("get_event_loop_manager: Reusing existing valid AsyncEventLoopManager instance.")
        return _event_loop_manager
205
+
206
def shutdown_event_loop_manager():
    """Tear down the global AsyncEventLoopManager, if any exists."""
    global _event_loop_manager

    with _event_loop_lock:
        mgr = _event_loop_manager
        if mgr is None:
            print("shutdown_event_loop_manager: No active event loop manager to shut down.")
            return

        print("shutdown_event_loop_manager: Shutting down global event loop manager.")
        try:
            mgr.shutdown()
        except Exception as e:
            print(f"shutdown_event_loop_manager: Error during shutdown: {e}")
        finally:
            # The reference is dropped even if shutdown raised.
            _event_loop_manager = None
220
+
221
class AsyncMCPClientWrapper:
    """Wrapper for async MCP client operations.

    Owns one smolagents MCPClient for a single server URL, with lazy
    connection establishment and a per-instance tools cache (5 minute TTL).
    """

    def __init__(self, url: str):
        self.url = url
        self._mcp_client = None        # underlying smolagents MCPClient, created on first connect
        self._tools = None             # last fetched tool list
        self._tools_cache_time = None  # time.time() of the last successful fetch
        self._cache_ttl = 300  # 5 minutes cache
        self._connected = False

    async def ensure_connected(self):
        """Ensure async connection is established (idempotent)."""
        if not self._connected or self._mcp_client is None:
            try:
                # Create MCP client with SSE transport for Gradio
                self._mcp_client = MCPClient({"url": self.url, "transport": "sse"})
                # Attempt a lightweight operation to confirm connectivity, e.g., get_tools or a custom ping
                # For now, we assume MCPClient constructor success implies basic connectivity.
                # If get_tools is lightweight enough, it can be called here.
                # await self._mcp_client.get_tools()  # Example, if get_tools is async and suitable
                self._connected = True
                print(f"✅ Connected to MCP server: {self.url}")
            except Exception as e:
                self._connected = False
                print(f"❌ Failed to connect to {self.url}: {e}")
                raise

    async def get_tools(self):
        """Get tools asynchronously, serving from the instance cache while fresh."""
        current_time = time.time()

        # Check instance cache
        if (self._tools is not None and
            self._tools_cache_time is not None and
            current_time - self._tools_cache_time < self._cache_ttl):
            return self._tools

        # Fetch fresh tools
        await self.ensure_connected()  # Ensures client is connected

        if self._mcp_client is None:  # Should be caught by ensure_connected, but as a safeguard
            raise RuntimeError("MCP client not connected")

        try:
            # Assuming MCPClient.get_tools() is a synchronous method based on original structure
            # If it were async, it would be `await self._mcp_client.get_tools()`
            self._tools = self._mcp_client.get_tools()
            self._tools_cache_time = current_time
            tool_names = [tool.name for tool in self._tools] if self._tools else []
            print(f"🔧 Fetched {len(tool_names)} tools from {self.url}: {tool_names}")

            return self._tools
        except Exception as e:
            print(f"❌ Error fetching tools from {self.url}: {e}")
            # Potentially reset connection status if error indicates a connection problem
            # self._connected = False
            raise

    async def disconnect(self):
        """Gracefully disconnect and drop the underlying client."""
        if self._mcp_client and self._connected:
            try:
                # Assuming MCPClient.disconnect() is synchronous
                # If it were async, it would be `await self._mcp_client.disconnect()`
                self._mcp_client.disconnect()
            except Exception as e:
                print(f"Error during MCPClient disconnect for {self.url}: {e}")
                # Log error but continue to mark as disconnected
                pass  # Fall through to set _connected = False
            self._connected = False
            self._mcp_client = None
            print(f"🔌 Disconnected from MCP server: {self.url}")
294
+
295
class AsyncPersistentMCPClient:
    """Async-aware persistent MCP client that survives multiple requests.

    Synchronous facade over AsyncMCPClientWrapper: every public method
    submits the wrapper's coroutine to the shared AsyncEventLoopManager
    and blocks for the result.  Tool lists are additionally cached in the
    module-level, per-server global cache.
    """

    _CACHE_TTL_SECONDS = 300  # how long the global tools cache stays fresh

    def __init__(self, url: str):
        self.url = url
        self._wrapper = AsyncMCPClientWrapper(url)
        self._loop_manager = None  # created lazily by _require_loop_manager()

    def _require_loop_manager(self):
        """Return a usable event loop manager, creating one on demand.

        Replaces the null-check block that was previously copy-pasted into
        ensure_connected(), get_tools() and disconnect().

        Raises RuntimeError if a manager cannot be obtained.
        """
        if self._loop_manager is None:
            self._loop_manager = get_event_loop_manager()
        if self._loop_manager is None:
            raise RuntimeError("Failed to create event loop manager")
        return self._loop_manager

    def ensure_connected(self):
        """Sync wrapper for async connection."""
        conn_step = track_communication("agent", "mcp_client", "connection_ensure", f"Ensuring connection to {self.url}")
        try:
            # Pass the coroutine object itself, not its result
            self._require_loop_manager().run_async(self._wrapper.ensure_connected())
            complete_workflow_step(conn_step, "completed", details={"url": self.url})
        except Exception as e:
            complete_workflow_step(conn_step, "error", details={"error": str(e)})
            raise

    def get_client(self):
        """Get the underlying MCP client (connecting first if needed)."""
        self.ensure_connected()
        return self._wrapper._mcp_client

    def get_tools(self):
        """Get tools with enhanced caching and async support.

        Returns the per-server globally cached tool list when fresh;
        otherwise fetches from the server and refreshes the cache.
        """
        global _global_tools_cache, _global_tools_timestamp
        current_time = time.time()
        server_cache_key = self.url

        # Phase 2 Optimization: Check server-specific global cache first.
        # (Loop manager creation is deferred until we actually need a fetch.)
        with _global_connection_lock:
            server_cache = _global_tools_cache.get(server_cache_key, {})

            if (server_cache and _global_tools_timestamp and
                    current_time - _global_tools_timestamp < self._CACHE_TTL_SECONDS):
                # Track global cache hit
                cache_step = track_communication("mcp_client", "mcp_server", "cache_hit_global", f"Using global cached tools for {self.url}")
                complete_workflow_step(cache_step, "completed", details={
                    "tools": list(server_cache.keys()),
                    "cache_type": "global_server_specific",
                    "server_url": self.url,
                    "cache_age": current_time - _global_tools_timestamp
                })
                return list(server_cache.values())

        # Fetch fresh tools using async
        tools_step = track_communication("mcp_client", "mcp_server", "get_tools", f"Fetching tools from {self.url} (cache refresh)")
        try:
            tools = self._require_loop_manager().run_async(self._wrapper.get_tools())

            # Update global cache
            with _global_connection_lock:
                if tools:
                    # Rebuild the per-server entry in one step (the old
                    # "insert empty dict first" branch was dead code: the
                    # entry was overwritten immediately afterwards).
                    _global_tools_cache[server_cache_key] = {tool.name: tool for tool in tools}
                    _global_tools_timestamp = current_time

                    total_tools = sum(len(server_tools) for server_tools in _global_tools_cache.values())
                    print(f"🔄 Global tools cache updated for {self.url}: {len(tools)} tools")
                    print(f"   Total cached tools across all servers: {total_tools}")

            tool_names = [tool.name for tool in tools] if tools else []
            complete_workflow_step(tools_step, "completed", details={
                "tools": tool_names,
                "count": len(tool_names),
                "server_url": self.url,
                "cache_status": "refreshed_server_specific",
                "global_cache_servers": len(_global_tools_cache)
            })
            return tools

        except Exception as e:
            complete_workflow_step(tools_step, "error", details={"error": str(e), "server_url": self.url})
            raise

    def disconnect(self):
        """Gracefully disconnect.

        Only attempts anything when a loop manager was already created --
        spinning one up just to disconnect would be wasted work.
        """
        if self._loop_manager and self._wrapper:
            try:
                self._loop_manager.run_async(self._wrapper.disconnect())
            except RuntimeError as e:
                # Loop may already be closed/unable to accept work at teardown.
                print(f"AsyncPersistentMCPClient: Error running disconnect for {self.url} in async loop: {e}")
            except Exception as e:
                print(f"AsyncPersistentMCPClient: General error during disconnect for {self.url}: {e}")
415
+
416
def get_mcp_client(url: str = "https://NLarchive-Agent-client-multi-mcp-SKT.hf.space/gradio_api/mcp/sse") -> AsyncPersistentMCPClient:
    """Get or create an MCP client with enhanced global connection pooling."""
    # Phase 2 Optimization: Use global connection pool
    with _global_connection_lock:
        client = _global_connection_pool.get(url)
        if client is None:
            step = track_communication("agent", "mcp_client", "connection_create", f"Creating new global connection to {url}")
            client = AsyncPersistentMCPClient(url)
            _global_connection_pool[url] = client
            complete_workflow_step(step, "completed", details={"url": url, "pool_size": len(_global_connection_pool)})
        else:
            # Track connection reuse
            step = track_communication("agent", "mcp_client", "connection_reuse", f"Reusing global connection to {url}")
            complete_workflow_step(step, "completed", details={"url": url, "pool_size": len(_global_connection_pool)})
        return client
430
+
431
def get_global_model() -> 'CachedLocalInferenceModel':
    """Get or create global model instance for Phase 2 optimization.

    Thread-safe via _global_model_lock.  The instance is assigned to the
    global BEFORE initialization (so a concurrent reader sees the object),
    and reset to None if initialization fails so the next call retries.
    """
    global _global_model_instance

    with _global_model_lock:
        if _global_model_instance is None:
            model_step = track_workflow_step("model_init_global", "Initializing global model instance")

            # CRITICAL FIX: Create and assign BEFORE initialization
            _global_model_instance = CachedLocalInferenceModel()

            # Now initialize the model
            try:
                _global_model_instance.ensure_initialized()
                complete_workflow_step(model_step, "completed", details={"model_type": "global_cached"})
                print(f"🤖 Global model instance created and initialized")
            except Exception as e:
                # If initialization fails, reset global instance
                _global_model_instance = None
                complete_workflow_step(model_step, "error", details={"error": str(e)})
                raise
        else:
            # Track model reuse
            reuse_step = track_workflow_step("model_reuse", "Reusing global model instance")
            complete_workflow_step(reuse_step, "completed", details={"model_type": "global_cached"})

        return _global_model_instance
458
+
459
def reset_global_state():
    """Reset global state for testing purposes with server-specific cache awareness.

    Clears the per-server tools cache in place, disconnects every pooled
    client (best effort) while keeping the pool structure, and leaves the
    global model instance untouched.
    """
    global _global_tools_cache, _global_tools_timestamp, _global_model_instance, _global_connection_pool, _event_loop_manager

    with _global_connection_lock:
        # FIXED: Clear server-specific cache structure (don't rebind!)
        _global_tools_cache.clear()  # Now clears {url: {tool_name: tool}} structure
        _global_tools_timestamp = None

        # Disconnect all connections but keep pool structure
        for client in _global_connection_pool.values():
            try:
                client.disconnect()
            except Exception as e:
                # BUGFIX: narrowed from a bare `except:` so KeyboardInterrupt /
                # SystemExit are no longer swallowed during teardown.
                print(f"reset_global_state: error disconnecting client: {e}")

    with _global_model_lock:
        # Don't reset model instance - it should persist
        pass

    print("🔄 Global state reset for testing (server-specific cache cleared)")
480
+
481
+ # Enhanced LocalInferenceModel with workflow tracking
482
class CachedLocalInferenceModel(Model):
    """Model with enhanced caching and session persistence.

    Wraps the local `inference` backend with:
    - lazy one-time initialization (ensure_initialized),
    - a small FIFO response cache keyed on hash(prompt), bounded to 10,
    - response-format validation/repair so the agent always receives a
      parseable "Thoughts:/Code:" block,
    - workflow-visualizer tracking of cache hits and LLM calls.
    """

    def __init__(self):
        super().__init__()
        self._response_cache = {}  # {hash(prompt): ModelResponse}, FIFO-evicted at 10 entries
        self._cache_hits = 0
        self._cache_misses = 0
        self._model_ready = False  # flipped once initialize() succeeds

    def ensure_initialized(self):
        """Lazy initialization of the model (idempotent)."""
        if not self._model_ready:
            init_step = track_workflow_step("model_init", "Initializing inference model (lazy)")
            try:
                initialize()
                self._model_ready = True
                complete_workflow_step(init_step, "completed")
            except Exception as e:
                complete_workflow_step(init_step, "error", details={"error": str(e)})
                raise

    def generate(self, messages: Any, **kwargs: Any) -> Any:
        """Generate a response for *messages*, serving from cache when possible.

        On backend failure, returns a canned error response that is still in
        valid code-block format (never raises to the caller).
        """
        self.ensure_initialized()

        prompt = self._format_messages(messages)

        # Enhanced cache with hash-based lookup
        cache_key = hash(prompt)
        if cache_key in self._response_cache:
            self._cache_hits += 1
            cached_response = self._response_cache[cache_key]

            # Track cache hit
            cache_step = track_communication("agent", "llm_service", "cache_hit", "Using cached response")
            complete_workflow_step(cache_step, "completed", details={
                "cache_hits": self._cache_hits,
                "cache_misses": self._cache_misses,
                "cache_ratio": self._cache_hits / (self._cache_hits + self._cache_misses)
            })

            # Re-wrap so token estimates reflect the current prompt
            return ModelResponse(cached_response.content, prompt)

        self._cache_misses += 1

        # Track LLM call
        llm_step = track_communication("agent", "llm_service", "generate_request", "Generating new response")

        try:
            enhanced_prompt = self._enhance_prompt_for_tools(prompt)

            response_text = generate_content(
                prompt=enhanced_prompt,
                model_name=kwargs.get('model_name'),
                allow_fallbacks=True,
                generation_config={
                    'temperature': kwargs.get('temperature', 0.3),
                    'max_output_tokens': kwargs.get('max_tokens', 512)
                }
            )

            # Validate and fix response format
            if not self._is_valid_code_response(response_text):
                response_text = self._fix_response_format(response_text, prompt)

            response = ModelResponse(str(response_text), prompt)

            # Smart cache management (keep most recent 10 responses)
            if len(self._response_cache) >= 10:
                # Remove oldest entry (simple FIFO; dicts preserve insertion order)
                oldest_key = next(iter(self._response_cache))
                del self._response_cache[oldest_key]

            self._response_cache[cache_key] = response

            complete_workflow_step(llm_step, "completed", details={
                "cache_status": "new",
                "input_tokens": response.token_usage.input_tokens,
                "output_tokens": response.token_usage.output_tokens,
                "model": response.model
            })

            return response

        except Exception as e:
            fallback_response = self._create_fallback_response(prompt, str(e))
            complete_workflow_step(llm_step, "error", details={"error": str(e)})
            return ModelResponse(fallback_response, prompt)

    def _enhance_prompt_for_tools(self, prompt: str) -> str:
        """Enhance the prompt with better tool usage examples."""
        if "sentiment" in prompt.lower():
            tool_example = """
IMPORTANT: When calling sentiment_analysis, use keyword arguments only:
Correct: sentiment_analysis(text="your text here")
Wrong: sentiment_analysis("your text here")

Example:
```py
text = "this is horrible"
result = sentiment_analysis(text=text)
final_answer(result)
```"""
            return prompt + "\n" + tool_example
        return prompt

    def _format_messages(self, messages: Any) -> str:
        """Convert messages (str, list of dicts/objects, or other) to a single prompt string."""
        if isinstance(messages, str):
            return messages
        elif isinstance(messages, list):
            prompt_parts = []
            for msg in messages:
                if isinstance(msg, dict):
                    if 'content' in msg:
                        content = msg['content']
                        role = msg.get('role', 'user')
                        if isinstance(content, list):
                            # Multi-part content: keep only the text parts
                            text_parts = [part.get('text', '') for part in content if part.get('type') == 'text']
                            content = ' '.join(text_parts)
                        prompt_parts.append(f"{role}: {content}")
                    elif 'text' in msg:
                        prompt_parts.append(msg['text'])
                elif hasattr(msg, 'content'):
                    prompt_parts.append(str(msg.content))
                else:
                    prompt_parts.append(str(msg))
            return '\n'.join(prompt_parts)
        else:
            return str(messages)

    def _is_valid_code_response(self, response: str) -> bool:
        """Check if response contains valid code block format."""
        code_pattern = r'```(?:py|python)?\s*\n(.*?)\n```'
        return bool(re.search(code_pattern, response, re.DOTALL))

    def _fix_response_format(self, response: str, original_prompt: str) -> str:
        """Try to fix response format to match expected pattern."""
        # Attempt to comment out a leading "Thoughts:" that is not inside a
        # code block already -- a common source of SyntaxError when the LLM
        # emits it as bare text.
        if "Thoughts:" in response and "```" not in response.split("Thoughts:")[0]:
            # If "Thoughts:" appears before any code block, comment it out
            response = response.replace("Thoughts:", "# Thoughts:", 1)

        if "sentiment" in original_prompt.lower():
            text_to_analyze = "neutral text"
            if "this is horrible" in original_prompt:
                text_to_analyze = "this is horrible"
            elif "awful" in original_prompt:
                text_to_analyze = "awful"

            return f"""Thoughts: I need to analyze the sentiment of the given text using the sentiment_analysis tool.
Code:
```py
text = "{text_to_analyze}"
result = sentiment_analysis(text=text)
final_answer(result)
```<end_code>"""

        if "```" in response and ("Thoughts:" in response or "Code:" in response):
            return response

        # BUGFIX: escape backslashes FIRST.  The old code only escaped quotes
        # and newlines, so any backslash already present in the response
        # produced an invalid or corrupted string literal in the generated
        # code block.
        clean_response = response.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
        return f"""Thoughts: Processing the user's request.
Code:
```py
result = "{clean_response}"
final_answer(result)
```<end_code>"""

    def _create_fallback_response(self, prompt: str, error_msg: str) -> str:
        """Create a valid fallback response when the model fails."""
        return f"""Thoughts: The AI service is experiencing issues, providing a fallback response.
Code:
```py
error_message = "I apologize, but the AI service is temporarily experiencing high load. Please try again in a moment."
final_answer(error_message)
```<end_code>"""
660
+
661
class TokenUsage:
    """Token accounting record with OpenAI-style alias attributes."""

    def __init__(self, input_tokens: int = 0, output_tokens: int = 0):
        # Canonical names and their completions-API aliases share one value.
        self.input_tokens = self.prompt_tokens = input_tokens
        self.output_tokens = self.completion_tokens = output_tokens
        self.total_tokens = input_tokens + output_tokens
668
+
669
class ModelResponse:
    """Lightweight LLM response wrapper with word-count token estimates."""

    def __init__(self, content: str, prompt: str = ""):
        self.content = content
        self.text = content
        # Rough token estimate: whitespace-separated word counts.
        in_est = len(prompt.split()) if prompt else 0
        out_est = len(content.split()) if content else 0
        self.token_usage = TokenUsage(in_est, out_est)
        self.finish_reason = 'stop'
        self.model = 'local-inference'

    def __str__(self):
        return self.content
681
+
682
# Global variables
_mcp_client = None            # NOTE(review): appears unused in this module's visible code -- confirm
_tools = None                 # combined unique tool list from all MCP servers
_model = None                 # shared model instance (from get_global_model())
_agent = None                 # the smolagents CodeAgent built in initialize_agent()
_initialized = False          # True once initialize_agent() completes (even in fallback mode)
_initialization_lock = threading.Lock()
689
+
690
def _connect_server_tools(label: str, url: str, all_tools: list, tool_names: set) -> bool:
    """Fetch tools from one MCP server, appending unseen ones to *all_tools*.

    Duplicate tool names (already in *tool_names*) are skipped.  Returns
    True on success, False when the server is unavailable.
    """
    try:
        client = get_mcp_client(url)
        tools = client.get_tools()
        for tool in tools:
            if tool.name not in tool_names:
                all_tools.append(tool)
                tool_names.add(tool.name)
        print(f"Connected to {label}: {len(tools)} tools - {[t.name for t in tools]}")
        return True
    except Exception as e:
        print(f"WARNING: {label.capitalize()} unavailable: {e}")
        return False

def initialize_agent():
    """Initialize the agent components with Hugging Face Spaces MCP servers.

    Idempotent: a second call is a tracked no-op.  If overall setup fails,
    the agent is still created in a tool-less fallback mode so callers can
    proceed.
    """
    global _mcp_client, _tools, _model, _agent, _initialized

    with _initialization_lock:
        if _initialized:
            skip_step = track_workflow_step("agent_init_skip", "Agent already initialized - using cached instance")
            complete_workflow_step(skip_step, "completed", details={"optimization": "session_persistence"})
            return

        try:
            print("Initializing MCP agent...")

            agent_init_step = track_workflow_step("agent_init", "Initializing MCP agent components")

            # Collect tools from each server; duplicates (by name) are dropped.
            # The three copy-pasted try/except blocks are now one table + helper.
            all_tools = []
            tool_names = set()
            servers = [
                ("semantic server", "https://nlarchive-mcp-semantic-keywords.hf.space/gradio_api/mcp/sse"),
                ("token counter server", "https://nlarchive-mcp-gr-token-counter.hf.space/gradio_api/mcp/sse"),
                ("sentiment analysis server", "https://nlarchive-mcp-sentiment.hf.space/gradio_api/mcp/sse"),
            ]
            servers_connected = sum(
                1 for label, url in servers
                if _connect_server_tools(label, url, all_tools, tool_names)
            )

            _tools = all_tools
            _model = get_global_model()

            # Create agent with unique tools only
            _agent = CodeAgent(tools=_tools, model=_model)

            complete_workflow_step(agent_init_step, "completed", details={
                "tools_count": len(_tools),
                "unique_tool_names": list(tool_names),
                # BUGFIX: was hard-coded to 3 even when servers failed
                "servers_connected": servers_connected
            })

            _initialized = True
            print(f"Agent initialized with {len(_tools)} unique tools: {list(tool_names)}")

        except Exception as e:
            print(f"Agent initialization failed: {e}")
            _model = get_global_model()
            _agent = CodeAgent(tools=[], model=_model)
            _initialized = True
            print("Agent initialized in fallback mode")
766
+
767
def is_agent_initialized() -> bool:
    """Report whether initialize_agent() has completed for this process."""
    return _initialized
770
+
771
def run_agent(message: str) -> str:
    """Send *message* through the agent and return its textual reply.

    Lazily initializes the agent on first use, records workflow/communication
    tracking steps around the call, and maps known infrastructure failures
    (503/overload, rate limit, event-loop issues) to user-friendly messages.

    Args:
        message: The user's natural-language request.

    Returns:
        The agent's answer as a string, or a friendly error message when the
        underlying service fails.

    Raises:
        RuntimeError: If the agent could not be initialized at all.
    """
    if not _initialized:
        initialize_agent()
    if _agent is None:
        raise RuntimeError("Agent not properly initialized")

    # Track agent processing
    process_step = track_workflow_step("agent_process", f"Processing: {message}")

    try:
        # Heuristic pre-detection of tools the agent is likely to call.
        # Used only for communication tracking; it does not influence which
        # tools _agent.run() actually invokes.
        detected_tools = _detect_tools(message)

        tool_step: Optional[str] = None
        if detected_tools:
            tool_step = track_communication("agent", "mcp_server", "tool_call",
                                          f"Executing tools {detected_tools} for: {message[:50]}...")

        result = _agent.run(message)

        # Complete tool step if it was tracked
        if tool_step is not None:
            complete_workflow_step(tool_step, "completed", details={
                "result": str(result)[:100],
                "detected_tools": detected_tools
            })

        complete_workflow_step(process_step, "completed", details={
            "result_length": len(str(result)),
            "detected_tools": detected_tools
        })

        return str(result)

    except Exception as e:
        error_msg = str(e)
        print(f"Agent execution error: {error_msg}")

        complete_workflow_step(process_step, "error", details={"error": error_msg})

        # Map known infrastructure failures to actionable user-facing replies.
        if "503" in error_msg or "overloaded" in error_msg.lower():
            return "I apologize, but the AI service is currently experiencing high demand. Please try again in a few moments."
        elif "rate limit" in error_msg.lower():
            return "The service is currently rate-limited. Please wait a moment before trying again."
        elif "event loop" in error_msg.lower():
            return "There was an async processing issue. The system is recovering. Please try again."
        else:
            return "I encountered an error while processing your request. Please try rephrasing your question or try again later."


def _detect_tools(message: str) -> list:
    """Return MCP tool names heuristically implied by keywords in *message*."""
    lowered = message.lower()  # hoisted: compute the lowercase form once
    detected = []
    if any(keyword in lowered for keyword in ('sentiment', 'analyze', 'feeling')):
        detected.append('sentiment_analysis')
    if any(keyword in lowered for keyword in ('token', 'count')):
        detected.extend(['count_tokens_openai_gpt4', 'count_tokens_bert_family'])
    if any(keyword in lowered for keyword in ('semantic', 'similar', 'keyword')):
        detected.extend(['semantic_similarity', 'extract_semantic_keywords'])
    return detected
829
+
830
def disconnect():
    """End the current session while preserving pooled connections.

    NOTE(review): despite the name, this does not close sockets — as a
    Phase 2 optimization the pooled MCP clients are kept alive (their
    ``_last_used`` timestamp is refreshed so they read as idle) for reuse
    by the next session; only the per-session ``_initialized`` flag is
    reset, in ``finally`` so it happens even if tracking fails.
    """
    global _initialized  # _initialized is the only global rebound here
    disconnect_step = track_workflow_step("agent_disconnect", "Disconnecting MCP client")

    try:
        # Phase 2 Optimization: Preserve global connections for reuse
        with _global_connection_lock:
            preserved_connections = 0
            # Only the client objects are needed; the URL keys are unused.
            for client in _global_connection_pool.values():
                try:
                    # Keep connections alive but mark as idle
                    if hasattr(client, '_last_used'):
                        client._last_used = time.time()
                    preserved_connections += 1
                except Exception:
                    # Best-effort bookkeeping: a broken client must not
                    # abort teardown (bare except would also trap
                    # SystemExit/KeyboardInterrupt).
                    pass

        complete_workflow_step(disconnect_step, "completed", details={
            "preserved_connections": preserved_connections,
            "optimization": "connection_persistence"
        })
    except Exception as e:
        complete_workflow_step(disconnect_step, "error", details={"error": str(e)})
    finally:
        # Don't reset global state - preserve for next session
        _initialized = False
857
+
858
def initialize_session():
    """Initialize the persistent session; thin alias for initialize_agent()."""
    return initialize_agent()
861
+
862
def is_session_initialized() -> bool:
    """Report whether the persistent session is ready; alias for is_agent_initialized()."""
    return is_agent_initialized()
865
+
866
# Make sure these are exported for imports.
# NOTE(review): the leading-underscore names are listed explicitly —
# presumably so `from agent import *` exposes the shared caches/locks to
# sibling modules (star-import would otherwise skip underscore names).
__all__ = [
    'run_agent', 'initialize_agent', 'is_agent_initialized', 'disconnect',
    'initialize_session', 'is_session_initialized',
    'get_mcp_client', 'get_global_model', 'reset_global_state',
    '_global_tools_cache', '_global_connection_pool', '_global_model_instance',
    '_global_connection_lock', '_global_model_lock'
]
874
+
875
# Register cleanup function
def cleanup_global_resources():
    """Cleanup function for graceful shutdown.

    Disconnects every pooled MCP client and shuts down the shared event
    loop manager. Failures are deliberately suppressed: this runs at
    interpreter exit (via atexit), where raising would only obscure the
    real shutdown path.
    """
    # Only _event_loop_manager is rebound; the pool and locks are read or
    # mutated in place, which needs no `global` declaration.
    global _event_loop_manager

    print("Cleaning up global resources...")

    with _global_connection_lock:
        for client in _global_connection_pool.values():
            try:
                client.disconnect()
            except Exception:
                # Best-effort teardown; bare except would also swallow
                # SystemExit/KeyboardInterrupt during shutdown.
                pass
        _global_connection_pool.clear()

    # Shutdown event loop manager
    with _event_loop_lock:
        if _event_loop_manager:
            try:
                _event_loop_manager.shutdown()
            except Exception:
                pass
            _event_loop_manager = None
898
+
899
# Register cleanup on exit so pooled MCP connections and the event loop
# manager are released when the interpreter shuts down.
atexit.register(cleanup_global_resources)