aelyazid committed
Commit 7b1f330 · verified · 1 Parent(s): 0fde184

Update agent.py

Files changed (1)
  1. agent.py +68 -117
agent.py CHANGED
@@ -11,109 +11,88 @@ from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 from langchain_community.document_loaders import ArxivLoader
 from langchain_community.vectorstores import SupabaseVectorStore
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
 from langchain_core.tools import tool
 from langchain.tools.retriever import create_retriever_tool
 from supabase.client import Client, create_client
 
 load_dotenv()
-supabase_url='https://qzydfaroejcpolxfgfim.supabase.co'
-supabase_key='eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6InF6eWRmYXJvZWpjcG9seGZnZmltIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc0OTUwNTQyMywiZXhwIjoyMDY1MDgxNDIzfQ.IBjtn1tPcogCF6DSf8dgR29aTsC61Qh0XueXYcEWG_Q'
+
+supabase_url = 'https://qzydfaroejcpolxfgfim.supabase.co'
+supabase_key = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6InF6eWRmYXJvZWpjcG9seGZnZmltIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImlhdCI6MTc0OTUwNTQyMywiZXhwIjoyMDY1MDgxNDIzfQ.IBjtn1tPcogCF6DSf8dgR29aTsC61Qh0XueXYcEWG_Q'
+
+
 @tool
 def multiply(a: int, b: int) -> int:
-    """Multiply two numbers.
-    Args:
-        a: first int
-        b: second int
-    """
+    """Multiply two numbers."""
     return a * b
 
+
 @tool
 def add(a: int, b: int) -> int:
-    """Add two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
+    """Add two numbers."""
     return a + b
 
+
 @tool
 def subtract(a: int, b: int) -> int:
-    """Subtract two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
+    """Subtract two numbers."""
     return a - b
 
+
 @tool
-def divide(a: int, b: int) -> int:
-    """Divide two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
+def divide(a: int, b: int) -> float:
+    """Divide two numbers."""
     if b == 0:
         raise ValueError("Cannot divide by zero.")
     return a / b
 
+
 @tool
 def modulus(a: int, b: int) -> int:
-    """Get the modulus of two numbers.
-
-    Args:
-        a: first int
-        b: second int
-    """
+    """Get the modulus of two numbers."""
     return a % b
 
+
 @tool
-def wiki_search(query: str) -> str:
-    """Search Wikipedia for a query and return maximum 2 results.
-
-    Args:
-        query: The search query."""
+def wiki_search(query: str) -> dict:
+    """Search Wikipedia for a query and return maximum 2 results."""
     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
     formatted_search_docs = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
-        ])
+        ]
+    )
     return {"wiki_results": formatted_search_docs}
 
+
 @tool
-def web_search(query: str) -> str:
-    """Search Tavily for a query and return maximum 3 results.
-
-    Args:
-        query: The search query."""
+def web_search(query: str) -> dict:
+    """Search Tavily for a query and return maximum 3 results."""
     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
     formatted_search_docs = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
             for doc in search_docs
-        ])
+        ]
+    )
     return {"web_results": formatted_search_docs}
 
+
 @tool
-def arvix_search(query: str) -> str:
-    """Search Arxiv for a query and return maximum 3 result.
-
-    Args:
-        query: The search query."""
+def arvix_search(query: str) -> dict:
+    """Search Arxiv for a query and return maximum 3 results."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
     formatted_search_docs = "\n\n---\n\n".join(
         [
             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
             for doc in search_docs
-        ])
+        ]
+    )
     return {"arvix_results": formatted_search_docs}
 
 
-
 # load the system prompt from the file
 with open("system_prompt.txt", "r", encoding="utf-8") as f:
     system_prompt = f.read()
@@ -121,24 +100,23 @@ with open("system_prompt.txt", "r", encoding="utf-8") as f:
 # System message
 sys_msg = SystemMessage(content=system_prompt)
 
-# build a retriever
-embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
+# Build embeddings and vector store client
+embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
 supabase: Client = create_client(supabase_url, supabase_key)
 
 vector_store = SupabaseVectorStore(
     client=supabase,
-    embedding= embeddings,
+    embedding=embeddings,
     table_name="documents",
     query_name="match_documents_langchain",
 )
+
 create_retriever_tool = create_retriever_tool(
     retriever=vector_store.as_retriever(),
     name="Question Search",
     description="A tool to retrieve similar questions from a vector store.",
 )
 
-
-
 tools = [
     multiply,
     add,
@@ -153,87 +131,60 @@ tools = [
 # Build graph function
 def build_graph(provider: str = "huggingface"):
     """Build the graph"""
-    # Load environment variables from .env file
+
     if provider == "google":
-        # Google Gemini
         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
     elif provider == "groq":
-        # Groq https://console.groq.com/docs/models
-        llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
+        llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
     elif provider == "huggingface":
-        # TODO: Add huggingface endpoint
         llm = ChatHuggingFace(
-            llm=HuggingFaceEndpoint(endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf"),
-            temperature=0,
+            llm=HuggingFaceEndpoint(endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf"),
+            temperature=0,
         )
-
-
     else:
         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-    # Bind tools to LLM
+
     llm_with_tools = llm.bind_tools(tools)
 
-    # Node
     def assistant(state: MessagesState):
         """Assistant node"""
         return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-    # def retriever(state: MessagesState):
-    #     """Retriever node"""
-    #     similar_question = vector_store.similarity_search(state["messages"][0].content)
-    #     example_msg = HumanMessage(
-    #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-    #     )
-    #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
-    from langchain_core.messages import AIMessage
 
     def retriever(state: MessagesState):
-        query = state["messages"][-1].content
-        # 1. Embed the query to vector
-        query_embedding = embeddings.embed_query(query)  # list of floats
-
-        # 2. Call the RPC function directly
-        response = supabase.rpc(
-            'match_documents_langchain',
-            {
-                'match_count': 2,
-                'query_embedding': query_embedding
-            }
-        ).execute()
-
-        docs = response.data
-        if not docs or len(docs) == 0:
-            answer = "Sorry, I couldn't find an answer to your question."
-        else:
-            content = docs[0]['content']  # get content of the first matched doc
-            # Extract answer if it has 'Final answer :' pattern
-            if "Final answer :" in content:
-                answer = content.split("Final answer :")[-1].strip()
-            else:
-                answer = content.strip()
-
-        return {"messages": [AIMessage(content=answer)]}
-
+        query = state["messages"][-1].content
+        query_embedding = embeddings.embed_query(query)  # list of floats
+
+        response = supabase.rpc(
+            'match_documents_langchain',
+            {
+                'match_count': 2,
+                'query_embedding': query_embedding
+            }
+        ).execute()
+
+        docs = response.data
+        if not docs or len(docs) == 0:
+            answer = "Sorry, I couldn't find an answer to your question."
+        else:
+            content = docs[0]['content']  # get content of the first matched doc
+            if "Final answer :" in content:
+                answer = content.split("Final answer :")[-1].strip()
+            else:
+                answer = content.strip()
+
+        return {"messages": [AIMessage(content=answer)]}
 
-    # builder = StateGraph(MessagesState)
-    #builder.add_node("retriever", retriever)
-    #builder.add_node("assistant", assistant)
-    #builder.add_node("tools", ToolNode(tools))
-    #builder.add_edge(START, "retriever")
-    #builder.add_edge("retriever", "assistant")
-    #builder.add_conditional_edges(
-    #    "assistant",
-    #    tools_condition,
-    #)
-    #builder.add_edge("tools", "assistant")
-
     builder = StateGraph(MessagesState)
     builder.add_node("retriever", retriever)
+    # If you want to integrate assistant and tools, uncomment and add edges accordingly
+    # builder.add_node("assistant", assistant)
+    # builder.add_node("tools", ToolNode(tools))
+    # builder.add_edge(START, "retriever")
+    # builder.add_edge("retriever", "assistant")
+    # builder.add_conditional_edges("assistant", tools_condition)
+    # builder.add_edge("tools", "assistant")
 
-    # Retriever is both the start and the finish point
     builder.set_entry_point("retriever")
     builder.set_finish_point("retriever")
 
-    # Compile graph
-    return builder.compile()
+    return builder.compile()
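To exercise the compiled graph end to end, a usage sketch, also not part of the commit (the provider and question are placeholders):

from langchain_core.messages import HumanMessage
from agent import build_graph

# The compiled graph currently starts and finishes at the retriever node,
# so the provider choice only matters if the commented assistant/tools
# edges are re-enabled.
graph = build_graph(provider="groq")

# MessagesState expects a list of messages; the retriever embeds the last
# one and answers from the closest Supabase match.
result = graph.invoke({"messages": [HumanMessage(content="What is 2 + 2?")]})
print(result["messages"][-1].content)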