File size: 5,521 Bytes
04ffb15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import os
from typing import Dict, Any
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_core.tools import tool
import logging
# Module-level logger, named after this module, shared by every search tool below.
logger = logging.getLogger(__name__)
@tool
def wiki_search(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query."""
    # NOTE: ``@tool`` exposes the docstring above as the tool description, so it
    # is kept short and implementation notes live in comments instead.
    #
    # Returns a single-key dict ``{"wiki_results": text}`` where ``text`` is
    # the formatted documents, "No results found", or an error message.
    # Exceptions are logged and converted to that error message, never raised.
    try:
        logger.info(f"Searching Wikipedia for: {query}")
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        if not search_docs:
            logger.warning("No Wikipedia results found")
            return {"wiki_results": "No results found"}

        # The opening tag ends with ">" (not "/>") so that the closing
        # </Document> emitted after the content is actually matched.
        formatted_search_docs = "\n\n---\n\n".join(
            f'<Document source="{doc.metadata.get("source", "")}" '
            f'page="{doc.metadata.get("page", "")}">\n'
            f"{doc.page_content}\n</Document>"
            for doc in search_docs
        )
        logger.info(f"Found {len(search_docs)} Wikipedia results")
        return {"wiki_results": formatted_search_docs}
    except Exception as e:
        # Best-effort tool: report the failure to the caller as text, but keep
        # the traceback in the log for debugging.
        logger.exception("Error searching Wikipedia: %s", e)
        return {"wiki_results": f"Error searching Wikipedia: {e}"}
@tool
def web_search(query: str) -> Dict[str, str]:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # NOTE: ``@tool`` exposes the docstring above as the tool description, so it
    # is kept short and implementation notes live in comments instead.
    #
    # Returns a single-key dict ``{"web_results": text}`` where ``text`` is the
    # formatted hits, "No results found", or an error message. Exceptions are
    # logged and converted to that error message, never raised.
    try:
        logger.info(f"Searching web for: {query}")
        search_docs = TavilySearchResults(max_results=3).invoke({"query": query})
        if not search_docs:
            logger.warning("No web results found")
            return {"web_results": "No results found"}

        # Anything other than a list of hits is treated as a failed search.
        if not isinstance(search_docs, list):
            logger.warning(f"Unexpected response format from Tavily: {type(search_docs)}")
            return {"web_results": "Error: Unexpected response format from Tavily"}

        # Tavily hits carry their link under "url"; "source" is kept as a
        # fallback so any payload that does provide it still works.
        # The opening tag ends with ">" (not "/>") so that the closing
        # </Document> emitted after the content is actually matched.
        formatted_search_docs = "\n\n---\n\n".join(
            f'<Document source="{doc.get("url") or doc.get("source", "")}" '
            f'page="{doc.get("page", "")}">\n'
            f'{doc.get("content", "")}\n</Document>'
            for doc in search_docs
        )
        logger.info(f"Found {len(search_docs)} web results")
        return {"web_results": formatted_search_docs}
    except Exception as e:
        # Best-effort tool: report the failure to the caller as text, but keep
        # the traceback in the log for debugging.
        logger.exception("Error searching web: %s", e)
        return {"web_results": f"Error searching web: {e}"}
@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # NOTE: ``@tool`` exposes the docstring above as the tool description, so it
    # is kept short and implementation notes live in comments instead.
    #
    # Returns a single-key dict ``{"arxiv_results": text}`` where ``text`` is
    # the formatted papers (each truncated to its first 1000 characters),
    # "No results found", or an error message. Exceptions are logged and
    # converted to that error message, never raised.
    try:
        logger.info(f"Searching Arxiv for: {query}")
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not search_docs:
            logger.warning("No Arxiv results found")
            return {"arxiv_results": "No results found"}

        formatted_parts = []
        for doc in search_docs:
            meta = doc.metadata
            # NOTE(review): ArxivLoader's default metadata appears to be keyed
            # by Title/Authors/Published/Summary rather than "source" - confirm
            # against the installed langchain_community version. Fall back so
            # the source attribute is not left empty.
            source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
            # The opening tag ends with ">" (not "/>") so that the closing
            # </Document> emitted after the content is actually matched.
            formatted_parts.append(
                f'<Document source="{source}" page="{meta.get("page", "")}">\n'
                f"{doc.page_content[:1000]}\n</Document>"
            )
        formatted_search_docs = "\n\n---\n\n".join(formatted_parts)
        logger.info(f"Found {len(search_docs)} Arxiv results")
        return {"arxiv_results": formatted_search_docs}
    except Exception as e:
        # Best-effort tool: report the failure to the caller as text, but keep
        # the traceback in the log for debugging.
        logger.exception("Error searching Arxiv: %s", e)
        return {"arxiv_results": f"Error searching Arxiv: {e}"}
@tool
def wiki_api_search(query: str) -> Dict[str, str]:
    """Search Wikipedia using API wrapper for better results.

    Args:
        query: The search query."""
    # The result is always a one-entry dict under this key: the wrapper's text,
    # "No results found", or an error message when the lookup raised.
    result_key = "wiki_api_results"
    # Text compared against the wrapper's output to detect a lookup with no hits.
    no_hit_text = "No good Wikipedia Search Result was found"
    try:
        logger.info(f"Searching Wikipedia API for: {query}")
        api = WikipediaAPIWrapper(top_k_results=3, doc_content_chars_max=4000)
        answer = api.run(query)
        if answer and answer.strip() != no_hit_text:
            logger.info("Found Wikipedia API results")
            return {result_key: answer}
        logger.warning("No Wikipedia API results found")
        return {result_key: "No results found"}
    except Exception as exc:
        failure = f"Error searching Wikipedia API: {str(exc)}"
        logger.error(failure)
        return {result_key: failure}
# All search tools defined in this module, collected in one list.
SEARCH_TOOLS = [
    wiki_search,
    web_search,
    arxiv_search,
    wiki_api_search,
]
class SearchTools:
    """Wrapper class for search tools to provide a unified interface.

    Each public method runs one of this module's ``@tool``-decorated search
    functions and unwraps its single-key result dict into a plain string.
    """

    def __init__(self):
        """Initialize search tools (the wrapper holds no state)."""

    @staticmethod
    def _run(search_tool, result_key: str, query: str) -> str:
        """Invoke *search_tool* with *query* and return the text under *result_key*.

        Returns an empty string when the tool's result has no such key.
        """
        # ``@tool`` wraps the function in a tool object; ``invoke`` is its
        # supported entry point, whereas calling the object directly relies on
        # the deprecated ``BaseTool.__call__``.
        result = search_tool.invoke({"query": query})
        return result.get(result_key, "")

    def search_wikipedia(self, query: str) -> str:
        """Search Wikipedia and return formatted results"""
        return self._run(wiki_search, "wiki_results", query)

    def search_wikipedia_api(self, query: str) -> str:
        """Search Wikipedia using API wrapper and return formatted results"""
        return self._run(wiki_api_search, "wiki_api_results", query)

    def search_web(self, query: str) -> str:
        """Search web and return formatted results"""
        return self._run(web_search, "web_results", query)

    def search_arxiv(self, query: str) -> str:
        """Search Arxiv and return formatted results"""
        return self._run(arxiv_search, "arxiv_results", query)