rakesh-dvg committed
Commit 2f0ef1f · verified · 1 Parent(s): 058a977

Update agent.py
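
The change collapses the duplicated imports, defers Supabase client creation to a lazy get_supabase_client() helper instead of connecting at import time, trims the tool docstrings, makes the search tools return the plain strings their signatures promise, and moves the vector store and retriever construction inside build_graph().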

Files changed (1): agent.py +60 -155
agent.py CHANGED
@@ -1,184 +1,101 @@
- """LangGraph Agent"""
  import os
  from dotenv import load_dotenv
  from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition
- from langgraph.prebuilt import ToolNode
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
  from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_community.document_loaders import WikipediaLoader
- from langchain_community.document_loaders import ArxivLoader
  from langchain_community.vectorstores import SupabaseVectorStore
- from langchain_core.messages import SystemMessage, HumanMessage
- from langchain_core.tools import tool
  from langchain.tools.retriever import create_retriever_tool
- from supabase.client import Client, create_client

  load_dotenv()

- import os
- from supabase import create_client


- def get_supabase_client():
-     supabase_url = os.environ.get("SUPABASE_URL")
-     supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
-     if not supabase_url or not supabase_key:
-         raise ValueError("Supabase URL or SERVICE_KEY environment variable not set.")
-     return create_client(supabase_url, supabase_key)

- # Then call get_supabase_client() inside functions, NOT at module-level.
- print(f"SUPABASE_URL: {os.environ.get('SUPABASE_URL')[:10]}..." if os.environ.get('SUPABASE_URL') else "SUPABASE_URL not set")
- print(f"SUPABASE_SERVICE_KEY: {os.environ.get('SUPABASE_SERVICE_KEY')[:10]}..." if os.environ.get('SUPABASE_SERVICE_KEY') else "SUPABASE_SERVICE_KEY not set")



  @tool
  def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a * b

  @tool
  def add(a: int, b: int) -> int:
-     """Add two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a + b

  @tool
  def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a - b

  @tool
  def divide(a: int, b: int) -> int:
-     """Divide two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
  def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
      return a % b

  @tool
  def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-
-     Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"wiki_results": formatted_search_docs}

  @tool
  def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.
-
-     Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"web_results": formatted_search_docs}

  @tool
  def arvix_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 result.
-
-     Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"arvix_results": formatted_search_docs}


- # load the system prompt from the file
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()

- # System message
  sys_msg = SystemMessage(content=system_prompt)

- # build a retriever
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
- supabase: Client = create_client(
-     os.environ.get("SUPABASE_URL"),
-     os.environ.get("SUPABASE_SERVICE_KEY"))
- vector_store = SupabaseVectorStore(
-     client=supabase,
-     embedding= embeddings,
-     table_name="documents",
-     query_name="match_documents_langchain",
- )
- create_retriever_tool = create_retriever_tool(
-     retriever=vector_store.as_retriever(),
-     name="Question Search",
-     description="A tool to retrieve similar questions from a vector store.",
- )


- tools = [
-     multiply,
-     add,
-     subtract,
-     divide,
-     modulus,
-     wiki_search,
-     web_search,
-     arvix_search,
- ]

- # Build graph function
  def build_graph(provider: str = "groq"):
-     """Build the graph"""
-     # Load environment variables from .env file
      if provider == "google":
-         # Google Gemini
          llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
      elif provider == "groq":
-         # Groq https://console.groq.com/docs/models
-         llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
      elif provider == "huggingface":
-         # TODO: Add huggingface endpoint
          llm = ChatHuggingFace(
              llm=HuggingFaceEndpoint(
                  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
@@ -186,45 +103,33 @@ def build_graph(provider: str = "groq"):
              ),
          )
      else:
-         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-     # Bind tools to LLM
      llm_with_tools = llm.bind_tools(tools)

-     # Node
      def assistant(state: MessagesState):
-         """Assistant node"""
          return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
      def retriever(state: MessagesState):
-         """Retriever node"""
-         similar_question = vector_store.similarity_search(state["messages"][0].content)
-         example_msg = HumanMessage(
-             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-         )
-         return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
-     builder = StateGraph(MessagesState)
-     builder.add_node("retriever", retriever)
-     builder.add_node("assistant", assistant)
-     builder.add_node("tools", ToolNode(tools))
-     builder.add_edge(START, "retriever")
-     builder.add_edge("retriever", "assistant")
-     builder.add_conditional_edges(
-         "assistant",
-         tools_condition,
-     )
-     builder.add_edge("tools", "assistant")

-     # Compile graph
-     return builder.compile()

- # test
  if __name__ == "__main__":
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     # Build the graph
-     graph = build_graph(provider="groq")
-     # Run the graph
-     messages = [HumanMessage(content=question)]
-     messages = graph.invoke({"messages": messages})
-     for m in messages["messages"]:
-         m.pretty_print()
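
The removed comment in the old file, "Then call get_supabase_client() inside functions, NOT at module-level", is the heart of the change: a module-level create_client(...) runs during import and fails immediately when the environment is incomplete, even for code paths that never touch Supabase. A minimal, self-contained sketch of the two styles; create_client is stubbed here so the example runs without a Supabase account, while the real module uses supabase.create_client:

import os

def create_client(url: str, key: str):
    # Stand-in for supabase.create_client so the sketch is runnable offline.
    return f"client for {url}"

# Eager, module-level style (what the commit removes): this would execute at
# import time and crash the whole module if the variables are unset.
# client = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])

# Lazy style (what the commit keeps): the check runs only when a caller
# actually needs Supabase.
def get_supabase_client():
    url = os.environ.get("SUPABASE_URL")
    key = os.environ.get("SUPABASE_SERVICE_KEY")
    if not url or not key:
        raise ValueError("Supabase environment variables are missing.")
    return create_client(url, key)

if __name__ == "__main__":
    try:
        get_supabase_client()
    except ValueError as e:
        print(e)  # raised at call time, not at import time

The full agent.py after the commit follows.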
 
 
  import os
  from dotenv import load_dotenv
+ from supabase import create_client
+ from supabase.client import Client
  from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition, ToolNode
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
  from langchain_google_genai import ChatGoogleGenerativeAI
  from langchain_groq import ChatGroq
  from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
  from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
  from langchain_community.vectorstores import SupabaseVectorStore
  from langchain.tools.retriever import create_retriever_tool

  load_dotenv()

+ # Check environment variables
+ SUPABASE_URL = os.environ.get("SUPABASE_URL")
+ SUPABASE_SERVICE_KEY = os.environ.get("SUPABASE_SERVICE_KEY")

+ print(f"SUPABASE_URL: {SUPABASE_URL[:10]}..." if SUPABASE_URL else "SUPABASE_URL not set")
+ print(f"SUPABASE_SERVICE_KEY: {SUPABASE_SERVICE_KEY[:10]}..." if SUPABASE_SERVICE_KEY else "SUPABASE_SERVICE_KEY not set")


+ def get_supabase_client():
+     if not SUPABASE_URL or not SUPABASE_SERVICE_KEY:
+         raise ValueError("Supabase environment variables are missing.")
+     return create_client(SUPABASE_URL, SUPABASE_SERVICE_KEY)


  @tool
  def multiply(a: int, b: int) -> int:
      return a * b

  @tool
  def add(a: int, b: int) -> int:
      return a + b

  @tool
  def subtract(a: int, b: int) -> int:
      return a - b

  @tool
  def divide(a: int, b: int) -> int:
      if b == 0:
          raise ValueError("Cannot divide by zero.")
      return a / b

  @tool
  def modulus(a: int, b: int) -> int:
      return a % b

  @tool
  def wiki_search(query: str) -> str:
+     docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     return "\n\n---\n\n".join([doc.page_content for doc in docs])

  @tool
  def web_search(query: str) -> str:
+     docs = TavilySearchResults(max_results=3).invoke(query=query)
+     return "\n\n---\n\n".join([doc.page_content for doc in docs])

  @tool
  def arvix_search(query: str) -> str:
+     docs = ArxivLoader(query=query, load_max_docs=3).load()
+     return "\n\n---\n\n".join([doc.page_content[:1000] for doc in docs])


+ tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arvix_search]

  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()

  sys_msg = SystemMessage(content=system_prompt)

+
  def build_graph(provider: str = "groq"):
+     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
+     supabase = get_supabase_client()
+     vector_store = SupabaseVectorStore(
+         client=supabase,
+         embedding=embeddings,
+         table_name="documents",
+         query_name="match_documents_langchain",
+     )
+     retriever_tool = create_retriever_tool(
+         retriever=vector_store.as_retriever(),
+         name="Question Search",
+         description="A tool to retrieve similar questions from a vector store.",
+     )
+
      if provider == "google":
          llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
      elif provider == "groq":
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
      elif provider == "huggingface":
          llm = ChatHuggingFace(
              llm=HuggingFaceEndpoint(
                  url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
              ),
          )
      else:
+         raise ValueError("Invalid provider specified")
+
      llm_with_tools = llm.bind_tools(tools)

      def assistant(state: MessagesState):
          return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
      def retriever(state: MessagesState):
+         similar = vector_store.similarity_search(state["messages"][0].content)
+         msg = HumanMessage(content=f"Similar question reference:\n\n{similar[0].page_content}")
+         return {"messages": [sys_msg] + state["messages"] + [msg]}
+
+     graph = StateGraph(MessagesState)
+     graph.add_node("retriever", retriever)
+     graph.add_node("assistant", assistant)
+     graph.add_node("tools", ToolNode(tools))
+     graph.add_edge(START, "retriever")
+     graph.add_edge("retriever", "assistant")
+     graph.add_conditional_edges("assistant", tools_condition)
+     graph.add_edge("tools", "assistant")
+
+     return graph.compile()


  if __name__ == "__main__":
+     g = build_graph("groq")
+     question = "When was Aquinas added to Wikipedia page on double effect?"
+     output = g.invoke({"messages": [HumanMessage(content=question)]})
+     for msg in output["messages"]:
+         msg.pretty_print()
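
Two reading notes on the result. First, retriever_tool is built inside build_graph() but never added to the module-level tools list that is bound to the LLM, so the model cannot invoke "Question Search" directly; similar-question context reaches it only through the retriever node. Second, web_search likely still needs attention: TavilySearchResults takes its input as a positional argument to invoke and returns a list of dicts, so invoke(query=query) and doc.page_content both look like they will fail at runtime. The retriever node itself is easy to exercise offline; a hedged sketch with the Supabase store replaced by a stub (the message classes are the real langchain_core ones, everything else is a stand-in):

from langchain_core.messages import SystemMessage, HumanMessage

class StubDoc:
    def __init__(self, page_content: str):
        self.page_content = page_content

class StubVectorStore:
    # Stand-in for SupabaseVectorStore: returns one canned match instead of
    # calling the match_documents_langchain RPC.
    def similarity_search(self, query: str):
        return [StubDoc("Q: What is 6 times 7?\nA: 42")]

sys_msg = SystemMessage(content="You are a helpful assistant.")
vector_store = StubVectorStore()

def retriever(state: dict):
    # Mirrors the retriever node above: fetch a similar solved question and
    # prepend the system prompt plus the reference to the running messages.
    similar = vector_store.similarity_search(state["messages"][0].content)
    msg = HumanMessage(content=f"Similar question reference:\n\n{similar[0].page_content}")
    return {"messages": [sys_msg] + state["messages"] + [msg]}

out = retriever({"messages": [HumanMessage(content="What is 6 times 7?")]})
for m in out["messages"]:
    print(type(m).__name__, "->", m.content.splitlines()[0])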