File size: 3,594 Bytes
b642bac
 
 
 
 
 
e1b3737
b642bac
 
 
 
 
948fcb3
 
 
b642bac
948fcb3
e1b3737
b642bac
 
 
 
 
 
 
 
 
 
 
948fcb3
 
 
 
b642bac
 
 
 
e1b3737
b642bac
 
 
 
 
 
 
 
e1b3737
 
 
 
 
 
 
 
 
 
 
b642bac
 
 
 
 
 
e1b3737
b642bac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""LangGraph Agent"""
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace, HuggingFaceEmbeddings
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.globals import set_debug
from langchain_groq import ChatGroq
from tools.search_tools import web_search, arvix_search, wiki_search
from tools.math_tools import multiply, add, subtract, divide
# from supabase.client import Client, create_client
# from langchain.tools.retriever import create_retriever_tool
# from langchain_community.vectorstores import SupabaseVectorStore
import json
from tools.multimodal_tools import extract_text, analyze_image_tool, analyze_audio_tool
from langchain_google_genai import ChatGoogleGenerativeAI

# set_debug(True)
load_dotenv()

# All tools exposed to the agent: basic arithmetic, web/knowledge search
# (DuckDuckGo-style web, Wikipedia, arXiv), and multimodal helpers
# (text extraction, image analysis, audio analysis).
tools = [
    multiply,
    add,
    subtract,
    divide,
    web_search,
    wiki_search,
    arvix_search,
    extract_text,
    analyze_image_tool,
    analyze_audio_tool
]

def build_graph():
    """Build and compile the LangGraph agent graph.

    The graph alternates between an LLM ``assistant`` node and a ``tools``
    node: the assistant may emit tool calls, which the ToolNode executes,
    feeding results back to the assistant until it produces a final answer.

    Returns:
        A compiled LangGraph graph operating on ``MessagesState``.
    """
    hf_token = os.getenv("HF_TOKEN")  # kept for the commented-out HF backend below
    api_key = os.getenv("GEMINI_API_KEY")
    # llm = HuggingFaceEndpoint(
    #     repo_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    #     huggingfacehub_api_token=hf_token,
    # )

    # chat = ChatHuggingFace(llm=llm, verbose=True)
    # llm_with_tools = chat.bind_tools(tools)

    # llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    # llm_with_tools = llm.bind_tools(tools)

    chat = ChatGoogleGenerativeAI(
        model="gemini-2.5-pro-preview-05-06",
        temperature=0,
        max_retries=2,
        google_api_key=api_key,
        thinking_budget=0,  # presumably disables "thinking" tokens — confirm param is honored
    )
    chat_with_tools = chat.bind_tools(tools)

    def assistant(state: MessagesState):
        """LLM node: prepend the system prompt and invoke the tool-bound model."""
        # BUG FIX: the original passed a bare str, which LangChain coerces to a
        # *human* message; wrap it in SystemMessage so the prompt actually gets
        # the system role. Also fixed the missing spaces the original implicit
        # string concatenation produced ("constraints.Pay attention", etc.).
        sys_msg = SystemMessage(content=(
            "You are a helpful assistant with access to tools. Understand user "
            "requests accurately. Use your tools when needed to answer "
            "effectively. Strictly follow all user instructions and constraints. "
            "Pay attention: your output needs to contain only the final answer "
            "without any reasoning since it will be strictly evaluated against a "
            "dataset which contains only the specific response. "
            "Your final output needs to be just the string or integer containing "
            "the answer, not an array or technical stuff."
        ))
        return {
            "messages": [chat_with_tools.invoke([sys_msg] + state["messages"])],
        }

    ## The graph: START -> assistant -> (tools -> assistant)* -> END
    builder = StateGraph(MessagesState)

    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        # If the latest message requires a tool, route to tools;
        # otherwise the assistant's reply is the final response.
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()

# Manual evaluation harness: run one sample question from sample.jsonl
# through the agent and print the full message trace.
if __name__ == "__main__":

    graph = build_graph()

    # Each line of sample.jsonl is one JSON object with at least the keys
    # 'Question' and 'Final answer' (the expected answer for eyeballing).
    with open('sample.jsonl', 'r') as jsonl_file:
        json_list = list(jsonl_file)

    start = 10  # revisit 5, 8,
    end = start + 1  # run a single question at a time
    for json_str in json_list[start:end]:
        # Robustness fix: tolerate blank/trailing lines in the .jsonl file,
        # which would otherwise crash json.loads.
        if not json_str.strip():
            continue
        json_data = json.loads(json_str)
        print(f"Question::::::::: {json_data['Question']}")
        print(f"Final answer::::: {json_data['Final answer']}")

        question = json_data['Question']
        messages = [HumanMessage(content=question)]
        messages = graph.invoke({"messages": messages})
        for m in messages["messages"]:
            m.pretty_print()