ngockhoinguyenpy committed
Commit 25fe8bb · verified · 1 Parent(s): cbf4732

Update app.py

Files changed (1):
  1. app.py  +196 -48
app.py CHANGED
@@ -1,74 +1,222 @@
-from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool
-import datetime
-import requests
-import pytz
 import yaml
-from tools.final_answer import FinalAnswerTool
-
-from Gradio_UI import GradioUI
-
-# Below is an example of a tool that does nothing. Amaze us with your creativity !
-@tool
-def fetch_zen_quote() -> str:
-    """Fetches a random zen quote and returns the 'h' field from the JSON response.
-
-    Returns:
-        A string containing the formatted quote with the author.
-    """
-    response = requests.get("https://zenquotes.io/api/random")
-    json_data = response.json()
-
-    # Extract the 'h' field from the JSON response
-    quote_html = json_data[0]["h"]
-
-    return quote_html
-
-@tool
-def get_current_time_in_timezone(timezone: str) -> str:
-    """A tool that fetches the current local time in a specified timezone.
     Args:
-        timezone: A string representing a valid timezone (e.g., 'America/New_York').
     """
     try:
-        # Create timezone object
-        tz = pytz.timezone(timezone)
-        # Get current time in that timezone
-        local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
-        return f"The current local time in {timezone} is: {local_time}"
-    except Exception as e:
-        return f"Error fetching time for timezone '{timezone}': {str(e)}"
-
-final_answer = FinalAnswerTool()
-
-# If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
-# model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
-
-model = HfApiModel(
-    max_tokens=2096,
-    temperature=0.5,
-    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',  # it is possible that this model may be overloaded
-    custom_role_conversions=None,
-)
-
-# Import tool from Hub
-image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
-
 with open("prompts.yaml", 'r') as stream:
     prompt_templates = yaml.safe_load(stream)
-
 agent = CodeAgent(
     model=model,
-    tools=[final_answer, image_generation_tool, fetch_zen_quote],  ## add your tools here (don't remove final answer)
     max_steps=6,
     verbosity_level=1,
     grammar=None,
     planning_interval=None,
-    name=None,
-    description=None,
     prompt_templates=prompt_templates
 )
-
-GradioUI(agent).launch()
+import feedparser
+import urllib.parse
 import yaml
+import gradio as gr
+from smolagents import CodeAgent, HfApiModel, tool
+
+# @tool
+# def fetch_latest_arxiv_papers(keywords: list, num_results: int = 3) -> list:
+#     """Fetches the latest research papers from arXiv based on provided keywords.
+
+#     Args:
+#         keywords: A list of keywords to search for relevant papers.
+#         num_results: The number of papers to fetch (default is 3).
+
+#     Returns:
+#         A list of dictionaries containing:
+#         - "title": The title of the research paper.
+#         - "authors": The authors of the paper.
+#         - "year": The publication year.
+#         - "abstract": A summary of the research paper.
+#         - "link": A direct link to the paper on arXiv.
+#     """
+#     try:
+#         print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")  # Debug input
+
+#         # Properly format query with +AND+ for multiple keywords
+#         query = "+AND+".join([f"all:{kw}" for kw in keywords])
+#         query_encoded = urllib.parse.quote(query)  # Encode spaces and special characters
+
+#         url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results={num_results}&sortBy=submittedDate&sortOrder=descending"
+
+#         print(f"DEBUG: Query URL - {url}")  # Debug URL
+
+#         feed = feedparser.parse(url)
+
+#         papers = []
+#         for entry in feed.entries:
+#             papers.append({
+#                 "title": entry.title,
+#                 "authors": ", ".join(author.name for author in entry.authors),
+#                 "year": entry.published[:4],  # Extract year
+#                 "abstract": entry.summary,
+#                 "link": entry.link
+#             })
+
+#         return papers
+
+#     except Exception as e:
+#         print(f"ERROR: {str(e)}")  # Debug errors
+#         return [f"Error fetching research papers: {str(e)}"]
+
+from rank_bm25 import BM25Okapi
+import nltk
+
+import os
+import shutil
+
+nltk_data_path = os.path.join(nltk.data.path[0], "tokenizers", "punkt")
+if os.path.exists(nltk_data_path):
+    shutil.rmtree(nltk_data_path)  # Remove corrupted version
+    print("✅ Removed old NLTK 'punkt' data. Reinstalling...")
+
+# ✅ Download the replacement 'punkt_tab' tokenizer data
+nltk.download("punkt_tab")
+
+print("✅ Successfully installed 'punkt'!")
+
+
+@tool  # Register the function properly as a SmolAgents tool
+def fetch_latest_arxiv_papers(keywords: list, num_results: int = 5) -> list:
+    """Fetches and ranks arXiv papers using BM25 keyword relevance.
     Args:
+        keywords: List of keywords for search.
+        num_results: Number of results to return.
+    Returns:
+        List of the most relevant papers based on BM25 ranking.
     """
     try:
+        print(f"DEBUG: Searching arXiv papers with keywords: {keywords}")
+
+        # Use a general keyword search (without `ti:` and `abs:` field prefixes)
+        query = "+AND+".join([f"all:{kw}" for kw in keywords])
+        query_encoded = urllib.parse.quote(query)
+        url = f"http://export.arxiv.org/api/query?search_query={query_encoded}&start=0&max_results=50&sortBy=submittedDate&sortOrder=descending"
+
+        print(f"DEBUG: Query URL - {url}")
+
+        feed = feedparser.parse(url)
+        papers = []
+
+        # Extract papers from the arXiv feed
+        for entry in feed.entries:
+            papers.append({
+                "title": entry.title,
+                "authors": ", ".join(author.name for author in entry.authors),
+                "year": entry.published[:4],
+                "abstract": entry.summary,
+                "link": entry.link
+            })
+
+        if not papers:
+            return [{"error": "No results found. Try different keywords."}]
+
+        # Apply BM25 ranking over title + abstract
+        tokenized_corpus = [nltk.word_tokenize(paper["title"].lower() + " " + paper["abstract"].lower()) for paper in papers]
+        bm25 = BM25Okapi(tokenized_corpus)
+
+        tokenized_query = nltk.word_tokenize(" ".join(keywords).lower())
+        scores = bm25.get_scores(tokenized_query)
+
+        # Sort papers by BM25 score, highest first
+        ranked_papers = sorted(zip(papers, scores), key=lambda x: x[1], reverse=True)
+
+        # Return the most relevant ones
+        return [paper[0] for paper in ranked_papers[:num_results]]
+
+    except Exception as e:
+        print(f"ERROR: {str(e)}")
+        return [{"error": f"Error fetching research papers: {str(e)}"}]
+
+
+# AI Model
+model = HfApiModel(
+    max_tokens=2096,
+    temperature=0.5,
+    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
+    custom_role_conversions=None,
+)
+
+# Load prompt templates
 with open("prompts.yaml", 'r') as stream:
     prompt_templates = yaml.safe_load(stream)
+
+# Create the AI Agent
 agent = CodeAgent(
     model=model,
+    tools=[fetch_latest_arxiv_papers],  # Properly registered tool
     max_steps=6,
     verbosity_level=1,
     grammar=None,
     planning_interval=None,
+    name="ScholarAgent",
+    description="An AI agent that fetches the latest research papers from arXiv based on user-defined keywords and filters.",
     prompt_templates=prompt_templates
 )
+
+# # Define Gradio Search Function
+# def search_papers(user_input):
+#     keywords = [kw.strip() for kw in user_input.split(",") if kw.strip()]  # Ensure valid keywords
+#     print(f"DEBUG: Received input keywords - {keywords}")  # Debug user input
+
+#     if not keywords:
+#         print("DEBUG: No valid keywords provided.")
+#         return "Error: Please enter at least one valid keyword."
+
+#     results = fetch_latest_arxiv_papers(keywords, num_results=3)  # Fetch 3 results
+#     print(f"DEBUG: Results received - {results}")  # Debug function output
+
+#     if isinstance(results, list) and results and isinstance(results[0], dict):
+#         # Format output with better readability and clarity
+#         formatted_results = "\n\n".join([
+#             f"---\n\n"
+#             f"📌 **Title:**\n{paper['title']}\n\n"
+#             f"👨‍🔬 **Authors:**\n{paper['authors']}\n\n"
+#             f"📅 **Year:** {paper['year']}\n\n"
+#             f"📖 **Abstract:**\n{paper['abstract'][:500]}... *(truncated for readability)*\n\n"
+#             f"[🔗 Read Full Paper]({paper['link']})\n\n"
+#             for paper in results
+#         ])
+#         return formatted_results
+
+#     print("DEBUG: No results found.")
+#     return "No results found. Try different keywords."
+
+# Search Papers
+def search_papers(user_input):
+    keywords = [kw.strip() for kw in user_input.split(",") if kw.strip()]  # Ensure valid keywords
+    print(f"DEBUG: Received input keywords - {keywords}")  # Debug user input
+
+    if not keywords:
+        print("DEBUG: No valid keywords provided.")
+        return "Error: Please enter at least one valid keyword."
+
+    results = fetch_latest_arxiv_papers(keywords, num_results=3)  # Fetch 3 results
+    print(f"DEBUG: Results received - {results}")  # Debug function output
+
+    # ✅ Check whether the tool returned an error
+    if isinstance(results, list) and len(results) > 0 and "error" in results[0]:
+        return results[0]["error"]  # Return the error message directly
+
+    # ✅ Format results only if valid papers exist
+    if isinstance(results, list) and results and isinstance(results[0], dict):
+        formatted_results = "\n\n".join([
+            f"---\n\n"
+            f"📌 **Title:** {paper['title']}\n\n"
+            f"👨‍🔬 **Authors:** {paper['authors']}\n\n"
+            f"📅 **Year:** {paper['year']}\n\n"
+            f"📖 **Abstract:** {paper['abstract'][:500]}... *(truncated for readability)*\n\n"
+            f"[🔗 Read Full Paper]({paper['link']})\n\n"
+            for paper in results
+        ])
+        return formatted_results
+
+    print("DEBUG: No results found.")
+    return "No results found. Try different keywords."
+
+
+# Create Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("# ScholarAgent")
+    keyword_input = gr.Textbox(label="Enter keywords (comma-separated)", placeholder="e.g., deep learning, reinforcement learning")
+    output_display = gr.Markdown()
+    search_button = gr.Button("Search")
+
+    search_button.click(search_papers, inputs=[keyword_input], outputs=[output_display])
+
+    print("DEBUG: Gradio UI is running. Waiting for user input...")
+
+# Launch Gradio App
+demo.launch()
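
The ranking step this commit introduces relies on rank_bm25's BM25Okapi. Below is a minimal standalone sketch of that step, with an invented three-document corpus and query; app.py tokenizes with nltk.word_tokenize, while this sketch uses a plain split() to stay dependency-light.

from rank_bm25 import BM25Okapi

# Toy corpus standing in for fetched arXiv titles + abstracts (invented for illustration)
corpus = [
    "BM25 is a bag-of-words ranking function for information retrieval",
    "Deep reinforcement learning agents for robotic manipulation",
    "A survey of transformer architectures for natural language processing",
]
tokenized_corpus = [doc.lower().split() for doc in corpus]

bm25 = BM25Okapi(tokenized_corpus)
query = "reinforcement learning robotics".split()
scores = bm25.get_scores(query)  # one relevance score per corpus document

# Highest-scoring documents first, as in fetch_latest_arxiv_papers
for doc, score in sorted(zip(corpus, scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {doc}")

Run as-is, this should print the robotics document first; app.py applies the same get_scores call to the 50 newest arXiv matches and keeps the top num_results.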