aelyazid committed
Commit 3898a54 · verified · 1 Parent(s): 520bb1b

Update agent.py

Files changed (1)
  1. agent.py +116 -668
agent.py CHANGED
@@ -1,36 +1,15 @@
  import os
  from dotenv import load_dotenv
- from typing import List, Dict, Any, Optional
- import tempfile
- import re
- import json
- import requests
- from urllib.parse import urlparse
- import pytesseract
- from PIL import Image, ImageDraw, ImageFont, ImageEnhance, ImageFilter
- import cmath
- import pandas as pd
- import uuid
- import numpy as np
- from code_interpreter import CodeInterpreter
-
- interpreter_instance = CodeInterpreter()
-
- from image_processing import *
-
- """Langraph"""
  from langgraph.graph import START, StateGraph, MessagesState
  from langchain_community.tools.tavily_search import TavilySearchResults
  from langchain_community.document_loaders import WikipediaLoader
  from langchain_community.document_loaders import ArxivLoader
- from langgraph.prebuilt import ToolNode, tools_condition
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_groq import ChatGroq
- from langchain_huggingface import (
-     ChatHuggingFace,
-     HuggingFaceEndpoint,
-     HuggingFaceEmbeddings,
- )
  from langchain_community.vectorstores import SupabaseVectorStore
  from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool
@@ -39,639 +18,118 @@ from supabase.client import Client, create_client

  load_dotenv()

- ### =============== BROWSER TOOLS =============== ###
-
-
- @tool
- def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-     Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ]
-     )
-     return {"wiki_results": formatted_search_docs}
-
-
- @tool
- def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.
-     Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ]
-     )
-     return {"web_results": formatted_search_docs}
-
-
- @tool
- def arxiv_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 result.
-     Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-             for doc in search_docs
-         ]
-     )
-     return {"arxiv_results": formatted_search_docs}
-
-
- ### =============== CODE INTERPRETER TOOLS =============== ###
-
-
  @tool
- def execute_code_multilang(code: str, language: str = "python") -> str:
-     """Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results.
      Args:
-         code (str): The source code to execute.
-         language (str): The language of the code. Supported: "python", "bash", "sql", "c", "java".
-     Returns:
-         A string summarizing the execution results (stdout, stderr, errors, plots, dataframes if any).
-     """
-     supported_languages = ["python", "bash", "sql", "c", "java"]
-     language = language.lower()
-
-     if language not in supported_languages:
-         return f"❌ Unsupported language: {language}. Supported languages are: {', '.join(supported_languages)}"
-
-     result = interpreter_instance.execute_code(code, language=language)
-
-     response = []
-
-     if result["status"] == "success":
-         response.append(f"✅ Code executed successfully in **{language.upper()}**")
-
-         if result.get("stdout"):
-             response.append(
-                 "\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```"
-             )
-
-         if result.get("stderr"):
-             response.append(
-                 "\n**Standard Error (if any):**\n```\n"
-                 + result["stderr"].strip()
-                 + "\n```"
-             )
-
-         if result.get("result") is not None:
-             response.append(
-                 "\n**Execution Result:**\n```\n"
-                 + str(result["result"]).strip()
-                 + "\n```"
-             )
-
-         if result.get("dataframes"):
-             for df_info in result["dataframes"]:
-                 response.append(
-                     f"\n**DataFrame `{df_info['name']}` (Shape: {df_info['shape']})**"
-                 )
-                 df_preview = pd.DataFrame(df_info["head"])
-                 response.append("First 5 rows:\n```\n" + str(df_preview) + "\n```")
-
-         if result.get("plots"):
-             response.append(
-                 f"\n**Generated {len(result['plots'])} plot(s)** (Image data returned separately)"
-             )
-
-     else:
-         response.append(f"❌ Code execution failed in **{language.upper()}**")
-         if result.get("stderr"):
-             response.append(
-                 "\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```"
-             )
-
-     return "\n".join(response)
-
-
- ### =============== MATHEMATICAL TOOLS =============== ###
-
-
- @tool
- def multiply(a: float, b: float) -> float:
-     """
-     Multiplies two numbers.
-     Args:
-         a (float): the first number
-         b (float): the second number
      """
      return a * b

-
  @tool
- def add(a: float, b: float) -> float:
-     """
-     Adds two numbers.
      Args:
-         a (float): the first number
-         b (float): the second number
      """
      return a + b

-
  @tool
- def subtract(a: float, b: float) -> int:
-     """
-     Subtracts two numbers.
      Args:
-         a (float): the first number
-         b (float): the second number
      """
      return a - b

-
  @tool
- def divide(a: float, b: float) -> float:
-     """
-     Divides two numbers.
      Args:
-         a (float): the first float number
-         b (float): the second float number
      """
      if b == 0:
-         raise ValueError("Cannot divided by zero.")
      return a / b

-
  @tool
  def modulus(a: int, b: int) -> int:
-     """
-     Get the modulus of two numbers.
      Args:
-         a (int): the first number
-         b (int): the second number
      """
      return a % b

-
- @tool
- def power(a: float, b: float) -> float:
-     """
-     Get the power of two numbers.
-     Args:
-         a (float): the first number
-         b (float): the second number
-     """
-     return a**b
-
-
- @tool
- def square_root(a: float) -> float | complex:
-     """
-     Get the square root of a number.
-     Args:
-         a (float): the number to get the square root of
-     """
-     if a >= 0:
-         return a**0.5
-     return cmath.sqrt(a)
-
-
- ### =============== DOCUMENT PROCESSING TOOLS =============== ###
-
-
- @tool
- def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
-     """
-     Save content to a file and return the path.
-     Args:
-         content (str): the content to save to the file
-         filename (str, optional): the name of the file. If not provided, a random name file will be created.
-     """
-     temp_dir = tempfile.gettempdir()
-     if filename is None:
-         temp_file = tempfile.NamedTemporaryFile(delete=False, dir=temp_dir)
-         filepath = temp_file.name
-     else:
-         filepath = os.path.join(temp_dir, filename)
-
-     with open(filepath, "w") as f:
-         f.write(content)
-
-     return f"File saved to {filepath}. You can read this file to process its contents."
-
-
- @tool
- def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
-     """
-     Download a file from a URL and save it to a temporary location.
-     Args:
-         url (str): the URL of the file to download.
-         filename (str, optional): the name of the file. If not provided, a random name file will be created.
-     """
-     try:
-         # Parse URL to get filename if not provided
-         if not filename:
-             path = urlparse(url).path
-             filename = os.path.basename(path)
-             if not filename:
-                 filename = f"downloaded_{uuid.uuid4().hex[:8]}"
-
-         # Create temporary file
-         temp_dir = tempfile.gettempdir()
-         filepath = os.path.join(temp_dir, filename)
-
-         # Download the file
-         response = requests.get(url, stream=True)
-         response.raise_for_status()
-
-         # Save the file
-         with open(filepath, "wb") as f:
-             for chunk in response.iter_content(chunk_size=8192):
-                 f.write(chunk)
-
-         return f"File downloaded to {filepath}. You can read this file to process its contents."
-     except Exception as e:
-         return f"Error downloading file: {str(e)}"
-
-
  @tool
- def extract_text_from_image(image_path: str) -> str:
-     """
-     Extract text from an image using OCR library pytesseract (if available).
-     Args:
-         image_path (str): the path to the image file.
-     """
-     try:
-         # Open the image
-         image = Image.open(image_path)
-
-         # Extract text from the image
-         text = pytesseract.image_to_string(image)
-
-         return f"Extracted text from image:\n\n{text}"
-     except Exception as e:
-         return f"Error extracting text from image: {str(e)}"
-
-
- @tool
- def analyze_csv_file(file_path: str, query: str) -> str:
-     """
-     Analyze a CSV file using pandas and answer a question about it.
-     Args:
-         file_path (str): the path to the CSV file.
-         query (str): Question about the data
-     """
-     try:
-         # Read the CSV file
-         df = pd.read_csv(file_path)
-
-         # Run various analyses based on the query
-         result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
-         result += f"Columns: {', '.join(df.columns)}\n\n"
-
-         # Add summary statistics
-         result += "Summary statistics:\n"
-         result += str(df.describe())
-
-         return result
-
-     except Exception as e:
-         return f"Error analyzing CSV file: {str(e)}"
-
-
- @tool
- def analyze_excel_file(file_path: str, query: str) -> str:
-     """
-     Analyze an Excel file using pandas and answer a question about it.
-     Args:
-         file_path (str): the path to the Excel file.
-         query (str): Question about the data
-     """
-     try:
-         # Read the Excel file
-         df = pd.read_excel(file_path)
-
-         # Run various analyses based on the query
-         result = (
-             f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
-         )
-         result += f"Columns: {', '.join(df.columns)}\n\n"
-
-         # Add summary statistics
-         result += "Summary statistics:\n"
-         result += str(df.describe())
-
-         return result
-
-     except Exception as e:
-         return f"Error analyzing Excel file: {str(e)}"
-
-
- ### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
-
-
- @tool
- def analyze_image(image_base64: str) -> Dict[str, Any]:
-     """
-     Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).
-     Args:
-         image_base64 (str): Base64 encoded image string
-     Returns:
-         Dictionary with analysis result
-     """
-     try:
-         img = decode_image(image_base64)
-         width, height = img.size
-         mode = img.mode
-
-         if mode in ("RGB", "RGBA"):
-             arr = np.array(img)
-             avg_colors = arr.mean(axis=(0, 1))
-             dominant = ["Red", "Green", "Blue"][np.argmax(avg_colors[:3])]
-             brightness = avg_colors.mean()
-             color_analysis = {
-                 "average_rgb": avg_colors.tolist(),
-                 "brightness": brightness,
-                 "dominant_color": dominant,
-             }
-         else:
-             color_analysis = {"note": f"No color analysis for mode {mode}"}
-
-         thumbnail = img.copy()
-         thumbnail.thumbnail((100, 100))
-         thumb_path = save_image(thumbnail, "thumbnails")
-         thumbnail_base64 = encode_image(thumb_path)
-
-         return {
-             "dimensions": (width, height),
-             "mode": mode,
-             "color_analysis": color_analysis,
-             "thumbnail": thumbnail_base64,
-         }
-     except Exception as e:
-         return {"error": str(e)}
-
-
- @tool
- def transform_image(
-     image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
- ) -> Dict[str, Any]:
-     """
-     Apply transformations: resize, rotate, crop, flip, brightness, contrast, blur, sharpen, grayscale.
-     Args:
-         image_base64 (str): Base64 encoded input image
-         operation (str): Transformation operation
-         params (Dict[str, Any], optional): Parameters for the operation
-     Returns:
-         Dictionary with transformed image (base64)
-     """
-     try:
-         img = decode_image(image_base64)
-         params = params or {}
-
-         if operation == "resize":
-             img = img.resize(
-                 (
-                     params.get("width", img.width // 2),
-                     params.get("height", img.height // 2),
-                 )
-             )
-         elif operation == "rotate":
-             img = img.rotate(params.get("angle", 90), expand=True)
-         elif operation == "crop":
-             img = img.crop(
-                 (
-                     params.get("left", 0),
-                     params.get("top", 0),
-                     params.get("right", img.width),
-                     params.get("bottom", img.height),
-                 )
-             )
-         elif operation == "flip":
-             if params.get("direction", "horizontal") == "horizontal":
-                 img = img.transpose(Image.FLIP_LEFT_RIGHT)
-             else:
-                 img = img.transpose(Image.FLIP_TOP_BOTTOM)
-         elif operation == "adjust_brightness":
-             img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
-         elif operation == "adjust_contrast":
-             img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
-         elif operation == "blur":
-             img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
-         elif operation == "sharpen":
-             img = img.filter(ImageFilter.SHARPEN)
-         elif operation == "grayscale":
-             img = img.convert("L")
-         else:
-             return {"error": f"Unknown operation: {operation}"}
-
-         result_path = save_image(img)
-         result_base64 = encode_image(result_path)
-         return {"transformed_image": result_base64}
-
-     except Exception as e:
-         return {"error": str(e)}
-
-
- @tool
- def draw_on_image(
-     image_base64: str, drawing_type: str, params: Dict[str, Any]
- ) -> Dict[str, Any]:
-     """
-     Draw shapes (rectangle, circle, line) or text onto an image.
      Args:
-         image_base64 (str): Base64 encoded input image
-         drawing_type (str): Drawing type
-         params (Dict[str, Any]): Drawing parameters
-     Returns:
-         Dictionary with result image (base64)
-     """
-     try:
-         img = decode_image(image_base64)
-         draw = ImageDraw.Draw(img)
-         color = params.get("color", "red")
-
-         if drawing_type == "rectangle":
-             draw.rectangle(
-                 [params["left"], params["top"], params["right"], params["bottom"]],
-                 outline=color,
-                 width=params.get("width", 2),
-             )
-         elif drawing_type == "circle":
-             x, y, r = params["x"], params["y"], params["radius"]
-             draw.ellipse(
-                 (x - r, y - r, x + r, y + r),
-                 outline=color,
-                 width=params.get("width", 2),
-             )
-         elif drawing_type == "line":
-             draw.line(
-                 (
-                     params["start_x"],
-                     params["start_y"],
-                     params["end_x"],
-                     params["end_y"],
-                 ),
-                 fill=color,
-                 width=params.get("width", 2),
-             )
-         elif drawing_type == "text":
-             font_size = params.get("font_size", 20)
-             try:
-                 font = ImageFont.truetype("arial.ttf", font_size)
-             except IOError:
-                 font = ImageFont.load_default()
-             draw.text(
-                 (params["x"], params["y"]),
-                 params.get("text", "Text"),
-                 fill=color,
-                 font=font,
-             )
-         else:
-             return {"error": f"Unknown drawing type: {drawing_type}"}
-
-         result_path = save_image(img)
-         result_base64 = encode_image(result_path)
-         return {"result_image": result_base64}
-
-     except Exception as e:
-         return {"error": str(e)}
-

  @tool
- def generate_simple_image(
-     image_type: str,
-     width: int = 500,
-     height: int = 500,
-     params: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Any]:
-     """
-     Generate a simple image (gradient, noise, pattern, chart).
      Args:
-         image_type (str): Type of image
-         width (int), height (int)
-         params (Dict[str, Any], optional): Specific parameters
-     Returns:
-         Dictionary with generated image (base64)
-     """
-     try:
-         params = params or {}
-
-         if image_type == "gradient":
-             direction = params.get("direction", "horizontal")
-             start_color = params.get("start_color", (255, 0, 0))
-             end_color = params.get("end_color", (0, 0, 255))
-
-             img = Image.new("RGB", (width, height))
-             draw = ImageDraw.Draw(img)
-
-             if direction == "horizontal":
-                 for x in range(width):
-                     r = int(
-                         start_color[0] + (end_color[0] - start_color[0]) * x / width
-                     )
-                     g = int(
-                         start_color[1] + (end_color[1] - start_color[1]) * x / width
-                     )
-                     b = int(
-                         start_color[2] + (end_color[2] - start_color[2]) * x / width
-                     )
-                     draw.line([(x, 0), (x, height)], fill=(r, g, b))
-             else:
-                 for y in range(height):
-                     r = int(
-                         start_color[0] + (end_color[0] - start_color[0]) * y / height
-                     )
-                     g = int(
-                         start_color[1] + (end_color[1] - start_color[1]) * y / height
-                     )
-                     b = int(
-                         start_color[2] + (end_color[2] - start_color[2]) * y / height
-                     )
-                     draw.line([(0, y), (width, y)], fill=(r, g, b))
-
-         elif image_type == "noise":
-             noise_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
-             img = Image.fromarray(noise_array, "RGB")
-
-         else:
-             return {"error": f"Unsupported image_type {image_type}"}
-
-         result_path = save_image(img)
-         result_base64 = encode_image(result_path)
-         return {"generated_image": result_base64}
-
-     except Exception as e:
-         return {"error": str(e)}
-

  @tool
- def combine_images(
-     images_base64: List[str], operation: str, params: Optional[Dict[str, Any]] = None
- ) -> Dict[str, Any]:
-     """
-     Combine multiple images (collage, stack, blend).
      Args:
-         images_base64 (List[str]): List of base64 images
-         operation (str): Combination type
-         params (Dict[str, Any], optional)
-     Returns:
-         Dictionary with combined image (base64)
-     """
-     try:
-         images = [decode_image(b64) for b64 in images_base64]
-         params = params or {}
-
-         if operation == "stack":
-             direction = params.get("direction", "horizontal")
-             if direction == "horizontal":
-                 total_width = sum(img.width for img in images)
-                 max_height = max(img.height for img in images)
-                 new_img = Image.new("RGB", (total_width, max_height))
-                 x = 0
-                 for img in images:
-                     new_img.paste(img, (x, 0))
-                     x += img.width
-             else:
-                 max_width = max(img.width for img in images)
-                 total_height = sum(img.height for img in images)
-                 new_img = Image.new("RGB", (max_width, total_height))
-                 y = 0
-                 for img in images:
-                     new_img.paste(img, (0, y))
-                     y += img.height
-         else:
-             return {"error": f"Unsupported combination operation {operation}"}
-
-         result_path = save_image(new_img)
-         result_base64 = encode_image(result_path)
-         return {"combined_image": result_base64}

-     except Exception as e:
-         return {"error": str(e)}


  # load the system prompt from the file
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()
- print(system_prompt)

  # System message
  sys_msg = SystemMessage(content=system_prompt)

  # build a retriever
- embeddings = HuggingFaceEmbeddings(
-     model_name="sentence-transformers/all-mpnet-base-v2"
- ) # dim=768
  supabase: Client = create_client(
-     os.environ.get("SUPABASE_URL"), os.environ.get("SUPABASE_SERVICE_ROLE_KEY")
- )
  vector_store = SupabaseVectorStore(
      client=supabase,
-     embedding=embeddings,
-     table_name="documents2",
-     query_name="match_documents_2",
  )
  create_retriever_tool = create_retriever_tool(
      retriever=vector_store.as_retriever(),
@@ -680,53 +138,38 @@ create_retriever_tool = create_retriever_tool(
  )


  tools = [
-     web_search,
-     wiki_search,
-     arxiv_search,
      multiply,
      add,
      subtract,
      divide,
      modulus,
-     power,
-     square_root,
-     save_and_read_file,
-     download_file_from_url,
-     extract_text_from_image,
-     analyze_csv_file,
-     analyze_excel_file,
-     execute_code_multilang,
-     analyze_image,
-     transform_image,
-     draw_on_image,
-     generate_simple_image,
-     combine_images,
  ]

-
  # Build graph function
- def build_graph(provider: str = "groq"):
      """Build the graph"""
      # Load environment variables from .env file
-     if provider == "groq":
          # Groq https://console.groq.com/docs/models
-         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
      elif provider == "huggingface":
          # TODO: Add huggingface endpoint
          llm = ChatHuggingFace(
              llm=HuggingFaceEndpoint(
-                 repo_id="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-                 task="text-generation", # for chat‐style use “text-generation”
-                 max_new_tokens=1024,
-                 do_sample=False,
-                 repetition_penalty=1.03,
                  temperature=0,
              ),
-             verbose=True,
          )
      else:
-         raise ValueError("Invalid provider. Choose 'groq' or 'huggingface'.")
      # Bind tools to LLM
      llm_with_tools = llm.bind_tools(tools)

@@ -734,42 +177,47 @@ def build_graph(provider: str = "groq"):
734
  def assistant(state: MessagesState):
735
  """Assistant node"""
736
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
 
 
 
 
 
737
 
738
  def retriever(state: MessagesState):
739
- """Retriever node"""
740
- similar_question = vector_store.similarity_search(state["messages"][0].content)
741
 
742
- if similar_question: # Check if the list is not empty
743
- example_msg = HumanMessage(
744
- content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
745
- )
746
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
747
  else:
748
- # Handle the case when no similar questions are found
749
- return {"messages": [sys_msg] + state["messages"]}
 
 
 
 
 
 
 
 
 
 
 
 
 
750
 
751
  builder = StateGraph(MessagesState)
752
  builder.add_node("retriever", retriever)
753
- builder.add_node("assistant", assistant)
754
- builder.add_node("tools", ToolNode(tools))
755
- builder.add_edge(START, "retriever")
756
- builder.add_edge("retriever", "assistant")
757
- builder.add_conditional_edges(
758
- "assistant",
759
- tools_condition,
760
- )
761
- builder.add_edge("tools", "assistant")
762
-
763
- # Compile graph
764
- return builder.compile()
765
-
766
 
767
- # test
768
- if __name__ == "__main__":
769
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
770
- graph = build_graph(provider="groq")
771
- messages = [HumanMessage(content=question)]
772
- messages = graph.invoke({"messages": messages})
773
- for m in messages["messages"]:
774
- m.pretty_print()
775
 
 
 
 
1
+ """LangGraph Agent"""
2
  import os
3
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  from langgraph.graph import START, StateGraph, MessagesState
5
+ from langgraph.prebuilt import tools_condition
6
+ from langgraph.prebuilt import ToolNode
7
+ from langchain_google_genai import ChatGoogleGenerativeAI
8
+ from langchain_groq import ChatGroq
9
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
10
  from langchain_community.tools.tavily_search import TavilySearchResults
11
  from langchain_community.document_loaders import WikipediaLoader
12
  from langchain_community.document_loaders import ArxivLoader
 
 
 
 
 
 
 
 
13
  from langchain_community.vectorstores import SupabaseVectorStore
14
  from langchain_core.messages import SystemMessage, HumanMessage
15
  from langchain_core.tools import tool
 
18
 
19
  load_dotenv()
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  @tool
22
+ def multiply(a: int, b: int) -> int:
23
+ """Multiply two numbers.
24
  Args:
25
+ a: first int
26
+ b: second int
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  """
28
  return a * b
29
 
 
30
  @tool
31
+ def add(a: int, b: int) -> int:
32
+ """Add two numbers.
33
+
34
  Args:
35
+ a: first int
36
+ b: second int
37
  """
38
  return a + b
39
 
 
40
  @tool
41
+ def subtract(a: int, b: int) -> int:
42
+ """Subtract two numbers.
43
+
44
  Args:
45
+ a: first int
46
+ b: second int
47
  """
48
  return a - b
49
 
 
50
  @tool
51
+ def divide(a: int, b: int) -> int:
52
+ """Divide two numbers.
53
+
54
  Args:
55
+ a: first int
56
+ b: second int
57
  """
58
  if b == 0:
59
+ raise ValueError("Cannot divide by zero.")
60
  return a / b
61
 
 
62
  @tool
63
  def modulus(a: int, b: int) -> int:
64
+ """Get the modulus of two numbers.
65
+
66
  Args:
67
+ a: first int
68
+ b: second int
69
  """
70
  return a % b
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  @tool
73
+ def wiki_search(query: str) -> str:
74
+ """Search Wikipedia for a query and return maximum 2 results.
75
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  Args:
77
+ query: The search query."""
78
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
79
+ formatted_search_docs = "\n\n---\n\n".join(
80
+ [
81
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
82
+ for doc in search_docs
83
+ ])
84
+ return {"wiki_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  @tool
87
+ def web_search(query: str) -> str:
88
+ """Search Tavily for a query and return maximum 3 results.
89
+
 
 
 
 
 
90
  Args:
91
+ query: The search query."""
92
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
93
+ formatted_search_docs = "\n\n---\n\n".join(
94
+ [
95
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
+ for doc in search_docs
97
+ ])
98
+ return {"web_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  @tool
101
+ def arvix_search(query: str) -> str:
102
+ """Search Arxiv for a query and return maximum 3 result.
103
+
 
 
104
  Args:
105
+ query: The search query."""
106
+ search_docs = ArxivLoader(query=query, load_max_docs=3).load()
107
+ formatted_search_docs = "\n\n---\n\n".join(
108
+ [
109
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
+ for doc in search_docs
111
+ ])
112
+ return {"arvix_results": formatted_search_docs}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
 
 
114
 
115
 
116
  # load the system prompt from the file
117
  with open("system_prompt.txt", "r", encoding="utf-8") as f:
118
  system_prompt = f.read()
 
119
 
120
  # System message
121
  sys_msg = SystemMessage(content=system_prompt)
122
 
123
  # build a retriever
124
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
 
 
125
  supabase: Client = create_client(
126
+ os.environ.get("SUPABASE_URL"),
127
+ os.environ.get("SUPABASE_SERVICE_KEY"))
128
  vector_store = SupabaseVectorStore(
129
  client=supabase,
130
+ embedding= embeddings,
131
+ table_name="documents",
132
+ query_name="match_documents_langchain",
133
  )
134
  create_retriever_tool = create_retriever_tool(
135
  retriever=vector_store.as_retriever(),
 
138
  )
139
 
140
 
141
+
142
  tools = [
 
 
 
143
  multiply,
144
  add,
145
  subtract,
146
  divide,
147
  modulus,
148
+ wiki_search,
149
+ web_search,
150
+ arvix_search,
 
 
 
 
 
 
 
 
 
 
151
  ]
152
 
 
153
  # Build graph function
154
+ def build_graph(provider: str = "google"):
155
  """Build the graph"""
156
  # Load environment variables from .env file
157
+ if provider == "google":
158
+ # Google Gemini
159
+ llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
160
+ elif provider == "groq":
161
  # Groq https://console.groq.com/docs/models
162
+ llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
163
  elif provider == "huggingface":
164
  # TODO: Add huggingface endpoint
165
  llm = ChatHuggingFace(
166
  llm=HuggingFaceEndpoint(
167
+ url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
 
 
 
 
168
  temperature=0,
169
  ),
 
170
  )
171
  else:
172
+ raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
173
  # Bind tools to LLM
174
  llm_with_tools = llm.bind_tools(tools)
175
 
 
177
  def assistant(state: MessagesState):
178
  """Assistant node"""
179
  return {"messages": [llm_with_tools.invoke(state["messages"])]}
180
+
181
+ # def retriever(state: MessagesState):
182
+ # """Retriever node"""
183
+ # similar_question = vector_store.similarity_search(state["messages"][0].content)
184
+ #example_msg = HumanMessage(
185
+ # content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
186
+ # )
187
+ # return {"messages": [sys_msg] + state["messages"] + [example_msg]}
188
+
189
+ from langchain_core.messages import AIMessage
190
 
191
  def retriever(state: MessagesState):
192
+ query = state["messages"][-1].content
193
+ similar_doc = vector_store.similarity_search(query, k=1)[0]
194
 
195
+ content = similar_doc.page_content
196
+ if "Final answer :" in content:
197
+ answer = content.split("Final answer :")[-1].strip()
 
 
198
  else:
199
+ answer = content.strip()
200
+
201
+ return {"messages": [AIMessage(content=answer)]}
202
+
203
+ # builder = StateGraph(MessagesState)
204
+ #builder.add_node("retriever", retriever)
205
+ #builder.add_node("assistant", assistant)
206
+ #builder.add_node("tools", ToolNode(tools))
207
+ #builder.add_edge(START, "retriever")
208
+ #builder.add_edge("retriever", "assistant")
209
+ #builder.add_conditional_edges(
210
+ # "assistant",
211
+ # tools_condition,
212
+ #)
213
+ #builder.add_edge("tools", "assistant")
214
 
215
  builder = StateGraph(MessagesState)
216
  builder.add_node("retriever", retriever)
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+     # Retriever is both the entry point and the finish point
+     builder.set_entry_point("retriever")
+     builder.set_finish_point("retriever")

+     # Compile graph
+     return builder.compile()
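
For a quick smoke test of the rewritten graph, the __main__ block that this commit removes can be adapted to the new default provider. A minimal sketch, assuming this file is importable as agent (a hypothetical module path) and that the SUPABASE_* and provider API keys referenced above are set in .env; the question string is only a placeholder:

# Minimal smoke test, adapted from the __main__ block this commit removes.
# Assumption: agent.py is importable as "agent"; adjust the import to your layout.
from langchain_core.messages import HumanMessage

from agent import build_graph

if __name__ == "__main__":
    graph = build_graph(provider="google")
    question = "What is the capital of France?"  # placeholder question
    result = graph.invoke({"messages": [HumanMessage(content=question)]})
    for m in result["messages"]:
        m.pretty_print()

Note that in this version the compiled graph both starts and finishes at the retriever node, so the reply comes straight from the Supabase vector store; the assistant node, ToolNode, and conditional edges remain commented out and the bound LLM is never invoked.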