# SmolAgentsv2 / run.py
import argparse
import os
import threading
import sys
import logging
from io import StringIO
from contextlib import redirect_stdout, redirect_stderr
from dotenv import load_dotenv
from huggingface_hub import login
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
ArchiveSearchTool,
FinderTool,
FindNextTool,
PageDownTool,
PageUpTool,
SimpleTextBrowser,
VisitTool,
)
from scripts.visual_qa import visualizer
from smolagents import (
CodeAgent,
ToolCallingAgent,
LiteLLMModel,
DuckDuckGoSearchTool,
Tool,
)
# Modules the manager CodeAgent is allowed to import inside generated code snippets
# (passed to CodeAgent via additional_authorized_imports below).
AUTHORIZED_IMPORTS = [
    "shell_gpt", "sgpt", "openai", "requests", "zipfile", "os", "pandas", "numpy", "sympy", "json", "bs4",
    "pubchempy", "xml", "yahoo_finance", "Bio", "sklearn", "scipy", "pydub",
    "yaml", "string", "secrets", "io", "PIL", "chess", "PyPDF2", "pptx", "torch", "datetime", "fractions", "csv",
]
# Guards concurrent appends to a shared answer sink. Not referenced in this chunk --
# presumably used by callers elsewhere (e.g. app.py); verify before removing.
append_answer_lock = threading.Lock()
class StreamingHandler(logging.Handler):
    """Logging handler that forwards each formatted record to registered callbacks."""

    def __init__(self):
        super().__init__()
        self.callbacks = []

    def add_callback(self, callback):
        """Register a callable invoked with every formatted, newline-terminated log line."""
        self.callbacks.append(callback)

    def emit(self, record):
        formatted = self.format(record)
        # Ignore whitespace-only records so consumers never receive blank lines.
        if not formatted.strip():
            return
        # Append a newline so downstream line-splitters see one record per line.
        line = formatted + '\n'
        for cb in self.callbacks:
            cb(line)
class StreamingCapture(StringIO):
    """StringIO that tees every written chunk to registered callbacks.

    Used to intercept stdout/stderr: callbacks receive output in real time
    while the full transcript remains available via ``getvalue()``.
    """

    def __init__(self):
        super().__init__()
        self.callbacks = []

    def add_callback(self, callback):
        """Register a callable invoked with each raw chunk as it is written."""
        self.callbacks.append(callback)

    def write(self, s):
        # Forward non-empty chunks immediately so consumers see output live.
        if s:
            for callback in self.callbacks:
                callback(s)
        # Fix: return the character count from the underlying StringIO.
        # The original returned None, violating the io.TextIOBase.write
        # contract (callers may rely on the number of characters written).
        return super().write(s)

    def flush(self):
        super().flush()
def create_agent(
    model_id="gpt-4o-mini",
    hf_token=None,
    openai_api_key=None,
    serpapi_key=None,
    api_endpoint=None,
    custom_api_endpoint=None,
    custom_api_key=None,
    search_provider="serper",
    search_api_key=None,
    custom_search_url=None
):
    """Build and return the top-level manager CodeAgent.

    Wires together: a LiteLLM-backed model, a SimpleTextBrowser with its web
    tools, a managed web-search ToolCallingAgent, and a manager CodeAgent that
    supervises it.

    Args:
        model_id: LiteLLM model identifier passed to LiteLLMModel.
        hf_token: optional HuggingFace token; login is attempted (best-effort) when set.
        openai_api_key: key used for OpenAI-named models or the default endpoint fallback.
        serpapi_key: stored in the browser config; consumed by ArchiveSearchTool.
        api_endpoint: default base URL, used only when no custom endpoint is given.
        custom_api_endpoint: together with custom_api_key, overrides base URL and key.
        custom_api_key: key for the custom endpoint.
        search_provider: "searxng", "serper", or anything else (generic fallback);
            all three paths currently instantiate DuckDuckGoSearchTool.
        search_api_key: provider key; only logged here, not wired into any tool.
        custom_search_url: SearxNG-style URL; currently has no effect (see warning below).

    Returns:
        The configured manager CodeAgent.
    """
    print("[DEBUG] Creating agent with model_id:", model_id)
    if hf_token:
        print("[DEBUG] Logging into HuggingFace")
        try:
            login(hf_token)
        except Exception as e:
            # Best-effort: agent creation continues without HF authentication.
            print(f"[ERROR] Failed to log into HuggingFace: {e}")
    model_params = {
        "model_id": model_id,
        # Remap tool roles for providers that only accept assistant/user turns.
        "custom_role_conversions": {"tool-call": "assistant", "tool-response": "user"},
        "max_completion_tokens": 8192,
    }
    if model_id == "gpt-4o-mini":
        # NOTE(review): reasoning_effort is normally an o-series/reasoning-model
        # parameter -- confirm LiteLLM accepts it for gpt-4o-mini.
        model_params["reasoning_effort"] = "high"
    # Determine which API key to use based on the model_id.
    # NOTE(review): "gpt-4o-mini" does not contain "openai", so OpenAI models
    # typically fall through to the api_endpoint branch below -- confirm intended.
    if "openai" in model_id.lower() and openai_api_key:
        print("[DEBUG] Using OpenAI API key for OpenAI model")
        model_params["api_key"] = openai_api_key
    elif custom_api_endpoint and custom_api_key:
        print("[DEBUG] Using custom API endpoint:", custom_api_endpoint)
        model_params["base_url"] = custom_api_endpoint
        model_params["api_key"] = custom_api_key
    elif api_endpoint and openai_api_key:  # Fallback to default OpenAI if custom not specified
        print("[DEBUG] Using default API endpoint:", api_endpoint)
        model_params["base_url"] = api_endpoint
        model_params["api_key"] = openai_api_key
    # It's important that if an API key is missing for the chosen model, it fails here or upstream.
    model = LiteLLMModel(**model_params)
    print("[DEBUG] Model initialized")
    text_limit = 100000  # character cap handed to TextInspectorTool
    user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
    browser_config = {
        "viewport_size": 1024 * 5,
        "downloads_folder": "downloads_folder",
        "request_kwargs": {
            "headers": {"User-Agent": user_agent},
            "timeout": 300,  # seconds, per request
        },
        "serpapi_key": serpapi_key,  # This will be used by ArchiveSearchTool if SerpAPI is enabled
    }
    # Ensure the download target exists before the browser tries to write to it.
    os.makedirs(f"./{browser_config['downloads_folder']}", exist_ok=True)
    browser = SimpleTextBrowser(**browser_config)
    print("[DEBUG] Browser initialized")
    search_tool = None
    if search_provider == "searxng":
        print("[DEBUG] Using DuckDuckGoSearchTool (acting as a generic web search) for SearxNG context.")
        search_tool = DuckDuckGoSearchTool()
        if custom_search_url:
            # Note: As mentioned before, DuckDuckGoSearchTool doesn't natively use a custom base_url
            # for a completely different search engine like SearxNG. This line will likely have no effect.
            # For true SearxNG integration, you'd need a custom tool or a modified DuckDuckGoSearchTool
            # that knows how to query SearxNG instances.
            print(f"[WARNING] DuckDuckGoSearchTool does not directly support 'custom_search_url' for SearxNG. Consider a dedicated SearxNG tool.")
            # search_tool.base_url = custom_search_url # This line is often not effective for DDCSTool
    elif search_provider == "serper":
        print("[DEBUG] Using DuckDuckGoSearchTool (acting as a generic web search) for Serper context.")
        search_tool = DuckDuckGoSearchTool()  # You would need a separate SerperTool for direct Serper API calls.
        if search_api_key:
            print("[DEBUG] Serper API Key provided. Ensure your search tool (if custom) uses it.")
            # If you had a dedicated SerperTool, you'd pass search_api_key to it.
            # e.g., search_tool = SerperTool(api_key=search_api_key)
    else:
        print("[DEBUG] No specific search provider selected, or provider not directly supported. Defaulting to DuckDuckGoSearchTool.")
        search_tool = DuckDuckGoSearchTool()
    WEB_TOOLS = [
        search_tool,
        VisitTool(browser),
        PageUpTool(browser),
        PageDownTool(browser),
        FinderTool(browser),
        FindNextTool(browser),
        ArchiveSearchTool(browser),  # This tool specifically uses serpapi_key from browser_config
        TextInspectorTool(model, text_limit),
    ]
    text_webbrowser_agent = ToolCallingAgent(
        model=model,
        tools=[tool for tool in WEB_TOOLS if tool is not None],  # Filter out None if search_tool was not set
        max_steps=20,
        verbosity_level=3,  # Keep this high for detailed output
        planning_interval=4,
        name="search_agent",
        description="A team member that will search the internet to answer your question.",
        provide_run_summary=True,
    )
    text_webbrowser_agent.prompt_templates["managed_agent"]["task"] += """You can navigate to .txt online files.
If a non-html page is in another format, especially .pdf or a Youtube video, use tool 'inspect_file_as_text' to inspect it.
Additionally, if after some searching you find out that you need more information to answer the question, you can use `final_answer` with your request for clarification as argument to request for more information."""
    manager_agent = CodeAgent(
        model=model,
        tools=[visualizer, TextInspectorTool(model, text_limit)],
        max_steps=16,
        verbosity_level=3,  # Keep this high for detailed output
        additional_authorized_imports=AUTHORIZED_IMPORTS,
        planning_interval=4,
        managed_agents=[text_webbrowser_agent],
    )
    print("[DEBUG] Agent fully initialized")
    return manager_agent
def run_agent_with_streaming(agent, question, stream_callback=None):
    """Run *agent* on *question*, streaming logs and stdout/stderr in real time.

    Temporarily installs a StreamingHandler on the root and 'smolagents'
    loggers and replaces sys.stdout/sys.stderr with StreamingCapture tees so
    that *stream_callback* (if provided) receives every line of output as it
    is produced. All logging state and the standard streams are restored on
    exit, even on error.

    Args:
        agent: object with a ``run(question)`` method (the manager CodeAgent).
        question: the question string passed to ``agent.run``.
        stream_callback: optional callable receiving raw output chunks.

    Returns:
        Whatever ``agent.run(question)`` returns.

    Raises:
        Re-raises any exception from ``agent.run`` after reporting it via the
        callback.
    """
    log_handler = StreamingHandler()
    if stream_callback:
        log_handler.add_callback(stream_callback)
    root_logger = logging.getLogger()
    smolagents_logger = logging.getLogger('smolagents')
    # Snapshot handlers (copies) and levels so they can be restored in `finally`.
    original_root_handlers = root_logger.handlers[:]
    original_smolagents_handlers = smolagents_logger.handlers[:]
    original_root_level = root_logger.level
    original_smolagents_level = smolagents_logger.level
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    stdout_capture = StreamingCapture()
    stderr_capture = StreamingCapture()
    if stream_callback:
        stdout_capture.add_callback(stream_callback)
        stderr_capture.add_callback(stream_callback)
    try:
        # Capture all verbose output: DEBUG level, our handler as the only one.
        root_logger.setLevel(logging.DEBUG)
        # Fix: iterate over a copy of the handler list. The original iterated
        # the live `handlers` list while calling removeHandler(), which
        # mutates the list mid-iteration and skips every other handler.
        for handler in root_logger.handlers[:]:
            root_logger.removeHandler(handler)
        root_logger.addHandler(log_handler)
        smolagents_logger.setLevel(logging.DEBUG)
        for handler in smolagents_logger.handlers[:]:
            smolagents_logger.removeHandler(handler)
        smolagents_logger.addHandler(log_handler)
        # Tee stdout/stderr through the streaming captures.
        sys.stdout = stdout_capture
        sys.stderr = stderr_capture
        if stream_callback:
            stream_callback(f"[STARTING] Running agent with question: {question}\n")
        answer = agent.run(question)
        if stream_callback:
            stream_callback(f"[COMPLETED] {answer}\n")
        return answer
    except Exception as e:
        error_msg = f"[ERROR] Exception occurred: {str(e)}\n"
        if stream_callback:
            stream_callback(error_msg)
        raise
    finally:
        # Restore original logging configuration and standard streams.
        root_logger.handlers = original_root_handlers
        root_logger.setLevel(original_root_level)
        smolagents_logger.handlers = original_smolagents_handlers
        smolagents_logger.setLevel(original_smolagents_level)
        sys.stdout = original_stdout
        sys.stderr = original_stderr
        # Flush any remaining buffered output in the captures.
        stdout_capture.flush()
        stderr_capture.flush()
def main():
    """CLI entry point: parse arguments, build the agent, and answer a question."""
    print("[DEBUG] Loading environment variables")
    load_dotenv(override=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("--gradio", action="store_true", help="Launch Gradio interface")
    parser.add_argument("question", type=str, nargs='?', help="Question to ask (CLI mode)")
    parser.add_argument("--model-id", type=str, default="gpt-4o-mini")
    parser.add_argument("--hf-token", type=str, default=os.getenv("HF_TOKEN"))
    parser.add_argument("--serpapi-key", type=str, default=os.getenv("SERPAPI_API_KEY"))
    parser.add_argument("--openai-api-key", type=str, default=os.getenv("OPENAI_API_KEY"))
    parser.add_argument("--api-endpoint", type=str, default=os.getenv("API_ENDPOINT", "https://api.openai.com/v1"))
    parser.add_argument("--custom-api-endpoint", type=str, default=None)
    parser.add_argument("--custom-api-key", type=str, default=None)
    parser.add_argument("--search-provider", type=str, default="searxng")
    parser.add_argument("--search-api-key", type=str, default=None)
    parser.add_argument("--custom-search-url", type=str, default="https://search.endorisk.nl/search")
    args = parser.parse_args()
    print("[DEBUG] CLI arguments parsed:", args)

    # Guard clauses: Gradio mode is delegated to app.py, and CLI mode needs a question.
    if args.gradio:
        print("Please run `app.py` directly to launch the Gradio interface.")
        return
    if not args.question:
        print("Error: Question required for CLI mode")
        return

    agent = create_agent(
        model_id=args.model_id,
        hf_token=args.hf_token,
        openai_api_key=args.openai_api_key,
        serpapi_key=args.serpapi_key,
        api_endpoint=args.api_endpoint,
        custom_api_endpoint=args.custom_api_endpoint,
        custom_api_key=args.custom_api_key,
        search_provider=args.search_provider,
        search_api_key=args.search_api_key,
        custom_search_url=args.custom_search_url,
    )
    print("[DEBUG] Running agent...")

    def emit_chunk(text):
        # Echo streamed chunks immediately, without buffering or extra newlines.
        print(text, end='', flush=True)

    answer = run_agent_with_streaming(agent, args.question, emit_chunk)
    print(f"\n\nGot this answer: {answer}")
# Run the CLI entry point only when executed as a script, not when imported.
if __name__ == "__main__":
    main()