# File size: 5,521 Bytes
# 04ffb15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
from typing import Dict, Any
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_core.tools import tool
import logging

# Module-level logger, named after this module, used by every search tool below.
logger = logging.getLogger(__name__)

@tool
def wiki_search(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return maximum 2 results.
    Args:
        query: The search query."""
    # NOTE: the docstring is kept verbatim; @tool presumably exposes it as the
    # tool description, so rewording it would change what consumers see.
    #
    # Returns a one-key dict {"wiki_results": <text>} where <text> is either
    # the formatted documents, "No results found", or an error message.
    try:
        logger.info(f"Searching Wikipedia for: {query}")
        pages = WikipediaLoader(query=query, load_max_docs=2).load()

        if pages:
            rendered = []
            for page in pages:
                meta = page.metadata
                rendered.append(
                    f'<Document source="{meta.get("source", "")}" page="{meta.get("page", "")}"/>\n{page.page_content}\n</Document>'
                )
            # Documents are separated by a horizontal-rule style delimiter.
            body = "\n\n---\n\n".join(rendered)
            logger.info(f"Found {len(pages)} Wikipedia results")
            return {"wiki_results": body}

        logger.warning("No Wikipedia results found")
        return {"wiki_results": "No results found"}
    except Exception as exc:
        # Best-effort tool: failures are logged and reported back as text
        # instead of propagating to the caller.
        message = f"Error searching Wikipedia: {str(exc)}"
        logger.error(message)
        return {"wiki_results": message}

@tool
def web_search(query: str) -> Dict[str, str]:
    """Search Tavily for a query and return maximum 3 results.
    Args:
        query: The search query."""
    # NOTE: the docstring is kept verbatim; @tool presumably exposes it as the
    # tool description, so rewording it would change what consumers see.
    #
    # Returns a one-key dict {"web_results": <text>} where <text> is either
    # the formatted documents, "No results found", or an error message.
    try:
        logger.info(f"Searching web for: {query}")
        search = TavilySearchResults(max_results=3)
        search_docs = search.invoke({"query": query})

        if not search_docs:
            logger.warning("No web results found")
            return {"web_results": "No results found"}

        # Anything other than a list of result dicts (e.g. an error string)
        # cannot be formatted; report it rather than guessing at its shape.
        if not isinstance(search_docs, list):
            logger.warning(f"Unexpected response format from Tavily: {type(search_docs)}")
            return {"web_results": "Error: Unexpected response format from Tavily"}

        chunks = []
        for doc in search_docs:
            # Tavily result entries carry their link under "url"; the old
            # lookup of "source" alone always produced an empty attribute.
            # "source" is still honoured as a fallback for compatibility.
            source = doc.get("url") or doc.get("source", "")
            chunks.append(
                f'<Document source="{source}" page="{doc.get("page", "")}"/>\n{doc.get("content", "")}\n</Document>'
            )
        formatted_search_docs = "\n\n---\n\n".join(chunks)
        logger.info(f"Found {len(search_docs)} web results")
        return {"web_results": formatted_search_docs}
    except Exception as e:
        # Best-effort tool: failures (including malformed result entries) are
        # logged and reported back as text instead of propagating.
        logger.error(f"Error searching web: {str(e)}")
        return {"web_results": f"Error searching web: {str(e)}"}

@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search Arxiv for a query and return maximum 3 results.
    Args:
        query: The search query."""
    # NOTE: the docstring is kept verbatim; @tool presumably exposes it as the
    # tool description, so rewording it would change what consumers see.
    #
    # Returns a one-key dict {"arxiv_results": <text>} where <text> is either
    # the formatted documents, "No results found", or an error message.
    try:
        logger.info(f"Searching Arxiv for: {query}")
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        if not search_docs:
            logger.warning("No Arxiv results found")
            return {"arxiv_results": "No results found"}

        chunks = []
        for doc in search_docs:
            meta = doc.metadata
            # ArxivLoader's default metadata has no "source" key (it provides
            # Published/Title/Authors/Summary), so the attribute used to be
            # always empty. Fall back to the paper title, with whitespace
            # collapsed so a wrapped title stays on one line.
            source = meta.get("source") or " ".join(str(meta.get("Title", "")).split())
            # Only the first 1000 characters of each paper are included.
            chunks.append(
                f'<Document source="{source}" page="{meta.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            )
        formatted_search_docs = "\n\n---\n\n".join(chunks)
        logger.info(f"Found {len(search_docs)} Arxiv results")
        return {"arxiv_results": formatted_search_docs}
    except Exception as e:
        # Best-effort tool: failures are logged and reported back as text
        # instead of propagating to the caller.
        logger.error(f"Error searching Arxiv: {str(e)}")
        return {"arxiv_results": f"Error searching Arxiv: {str(e)}"}

@tool
def wiki_api_search(query: str) -> Dict[str, str]:
    """Search Wikipedia using API wrapper for better results.
    Args:
        query: The search query."""
    # NOTE: the docstring is kept verbatim; @tool presumably exposes it as the
    # tool description, so rewording it would change what consumers see.
    #
    # Returns a one-key dict {"wiki_api_results": <text>} where <text> is the
    # wrapper's output, "No results found", or an error message.
    try:
        logger.info(f"Searching Wikipedia API for: {query}")
        api = WikipediaAPIWrapper(top_k_results=3, doc_content_chars_max=4000)
        answer = api.run(query)

        # An empty answer and the wrapper's "nothing found" sentence are both
        # treated as "no results".
        nothing_found = (
            not answer
            or answer.strip() == "No good Wikipedia Search Result was found"
        )
        if nothing_found:
            logger.warning("No Wikipedia API results found")
            return {"wiki_api_results": "No results found"}

        logger.info("Found Wikipedia API results")
        return {"wiki_api_results": answer}
    except Exception as exc:
        # Best-effort tool: failures are logged and reported back as text
        # instead of propagating to the caller.
        message = f"Error searching Wikipedia API: {str(exc)}"
        logger.error(message)
        return {"wiki_api_results": message}

# The four @tool-decorated search functions defined in this module, gathered
# in one list so they can be passed around together.
SEARCH_TOOLS = [wiki_search, web_search, arxiv_search, wiki_api_search]

class SearchTools:
    """Wrapper class for search tools to provide a unified interface.

    Each method runs one of the module's search tools and returns only the
    text stored under that tool's result key, or an empty string if the key
    is missing. The tools report "no results" and errors as text under the
    same key, so those messages are returned unchanged.
    """

    def __init__(self):
        """Initialize search tools (stateless; nothing to set up)."""

    # The module-level search functions are wrapped by @tool, so they are
    # tool objects rather than plain functions. They are run through
    # `.invoke(...)` with an explicit argument mapping instead of being
    # called directly, because calling a tool object directly is deprecated
    # in LangChain.

    def search_wikipedia(self, query: str) -> str:
        """Search Wikipedia and return formatted results.

        Args:
            query: The search query.

        Returns:
            The text under "wiki_results", or "" if that key is absent.
        """
        result = wiki_search.invoke({"query": query})
        return result.get("wiki_results", "")

    def search_wikipedia_api(self, query: str) -> str:
        """Search Wikipedia using API wrapper and return formatted results.

        Args:
            query: The search query.

        Returns:
            The text under "wiki_api_results", or "" if that key is absent.
        """
        result = wiki_api_search.invoke({"query": query})
        return result.get("wiki_api_results", "")

    def search_web(self, query: str) -> str:
        """Search web and return formatted results.

        Args:
            query: The search query.

        Returns:
            The text under "web_results", or "" if that key is absent.
        """
        result = web_search.invoke({"query": query})
        return result.get("web_results", "")

    def search_arxiv(self, query: str) -> str:
        """Search Arxiv and return formatted results.

        Args:
            query: The search query.

        Returns:
            The text under "arxiv_results", or "" if that key is absent.
        """
        result = arxiv_search.invoke({"query": query})
        return result.get("arxiv_results", "")