from setup import *
from gradio_embedding import *

import gradio as gr


from typing import Annotated, Sequence, Literal, List, Dict
from typing_extensions import TypedDict
from langgraph.graph import MessagesState
from langchain_core.documents import Document
from pydantic import BaseModel, Field

from langchain_core.messages import BaseMessage, AnyMessage, SystemMessage, HumanMessage
from langgraph.graph.message import add_messages
from langchain_core.runnables.config import RunnableConfig
import uuid
from langgraph.store.base import BaseStore
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langgraph.graph import END, StateGraph, START
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore

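# Load the Chroma vector store persisted by the embedding step; `embeddings`
# and `llm` are expected to come from the star imports above.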
persist_directory = "D:\\Education\\AI\\AI-Agents\\Agentic-RAG"

loaded_db = Chroma(
  persist_directory=persist_directory,
  embedding_function=embeddings,
  collection_name='sagemaker-chroma'
)

retriever = loaded_db.as_retriever()  # kept for convenience; retrieve_docs uses similarity_search directly





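# Graph state. MessagesState already provides `messages` with the add_messages
# reducer; the extra fields carry the raw history, retrieved documents, the
# rewritten search query, and the rolling conversation summary.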
class AgentState(MessagesState):
  history: List[AnyMessage]
  context: List[Document]
  length: int
  query: str
  summary: str




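# Response node: combines the rolling summary, the last few turns, retrieved
# context, and the user's long-term memory into one prompt and calls the LLM.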
def agent(state, config: RunnableConfig, store: BaseStore):

  print("----CALL AGENT----")
  messages = state['messages']
  context = state['context']
  previous_conversation_summary = state['summary']
  last_messages=state["messages"][-4:]

  system_template = '''You are a friendly and knowledgeable conversational assistant with memory specializing in AWS SageMaker. Your job is to answer questions about SageMaker clearly, accurately, and in a human-like, natural way—like a helpful teammate. Keep your tone polite, engaging, and professional.
    You can help with:
    SageMaker features, pricing, and capabilities
    Model training, deployment, tuning, and monitoring in SageMaker
    Integration with other AWS services in the context of SageMaker
    Troubleshooting common SageMaker issues

    You must not answer questions unrelated to AWS SageMaker. If asked, briefly and politely respond with something like:
    "Sorry, I can only help with AWS SageMaker. Let me know if you have any questions about that!"

    Take into consideration the summary of the conversation so far and last 4 messages
    Summary of the conversation so far: {previous_conversation_summary}
    last 4 messages : {last_messages}

    Also use the context from the docs retrieved for the user query:
    {context}
    
    Here is the memory (it may be empty): {memory}
    Keep responses concise, correct, and helpful. Make the conversation feel smooth and human, like chatting with a skilled colleague who knows SageMaker inside out.'''
  # Get the user ID from the config
  user_id = config["configurable"]["user_id"]

  # Retrieve memory from the store
  namespace = ("memory", user_id)
  key = "user_memory"
  existing_memory = store.get(namespace, key)

  # Extract the actual memory content if it exists and add a prefix
  if existing_memory:
      # Value is a dictionary with a memory key
      existing_memory_content = existing_memory.value.get('memory')
  else:
      existing_memory_content = "No existing memory found."


  system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)

  human_template = '''User last reply: {user_reply}'''
  human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) 

  chat_prompt = ChatPromptTemplate.from_messages([
    system_message_prompt,
    human_message_prompt
  ])

  formatted_messages = chat_prompt.format_messages(
    previous_conversation_summary=previous_conversation_summary,
    last_messages = last_messages,
    context = context,
    memory = existing_memory_content,
    user_reply = messages[-1].content
    )
  print('---------Message------', messages)

  response = llm.invoke(formatted_messages)

  return {'messages':[response]}


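# Rewrites the latest user message into a self-contained, retrieval-friendly
# query, using the conversation summary to resolve references.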
def rewrite_query(state):

  print("----REWRITE QUERY----")
  previous_summary = state['summary']
  last_messages=state["history"][-4:]
  sys_msg_query = '''TASK: Rewrite the user's query to improve retrieval performance.
    CONTEXT:
    - Conversation Summary: {previous_summary}
    - Recent Messages: {last_messages}

    INSTRUCTIONS:
    1. Identify the core information need in the user's message
    2. Extract key entities, concepts, and relationships
    3. Add relevant context from the conversation summary
    4. Remove conversational fillers and ambiguous references
    5. Use specific terminology that would match relevant documents
    6. Expand abbreviations and clarify ambiguous terms
    7. Format as a concise, search-optimized query

    Your rewritten query should:
    - Maintain the original intent
    - Be self-contained (not require conversation context to understand)
    - Include specific details that would match relevant documents
    - Be under 100 words

    OUTPUT FORMAT:
    Rewritten Query: '''

  rewritten = llm.invoke([SystemMessage(content=sys_msg_query.format(previous_summary=previous_summary,last_messages=last_messages))])
  rewritten_query = rewritten.content
  print('---QUERY---', rewritten_query)

  return {'query': rewritten_query}


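# Folds the last few turns into the rolling conversation summary.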
def summary_function(state):
  previous_summary = state['summary']
  last_messages=state["history"][-4:]

  sys_msg_summary = '''You are an AI that creates concise summaries of chat conversations. When summarizing:
  1. Capture all named entities (people, organizations, products, locations) and their relationships.
  2. Preserve explicit information, technical terms, and quantitative data.
  3. Identify the conversation's intent, requirements, and underlying needs.
  4. Incorporate previous summaries, resolving contradictions and updating information.
  5. Structure information logically, omitting small talk while retaining critical details.

  Your summary should begin with the conversation purpose, include all key points, and end with the conversation outcome or status. Remain neutral and accurate, ensuring someone can understand what happened without reading the entire transcript.

  Previous summary:
  {previous_summary}

  Last 4 messages:
  {last_messages}

  Summary:

  '''

  summarised = llm.invoke([SystemMessage(content=sys_msg_summary.format(previous_summary=previous_summary,last_messages=last_messages))])
  summarised_content = summarised.content
  print('----SUMMARY----', summarised_content)
  
  return {'summary' : summarised_content}


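# Conditional edge: refresh the summary every 4th turn (and never on an empty
# history).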
def summary_or_not(state):
  history = state['history']
  # Skip summarizing on the very first turn, when the history is still empty.
  return bool(history) and len(history) % 4 == 0
  
def write_memory(state: MessagesState, config: RunnableConfig, store: BaseStore):

  """Reflect on the chat history and save a memory to the store."""
  CREATE_MEMORY_INSTRUCTION = """You are collecting information about the user to personalize your responses.

  CURRENT USER INFORMATION:
  {memory}

  INSTRUCTIONS:
  1. Review the chat history below carefully
  2. Identify new information about the user, such as:
    - Personal details (name, location)
    - Preferences (likes, dislikes)
    - Interests and hobbies
    - Past experiences
    - Goals or future plans
  3. Merge any new information with existing memory
  4. Format the memory as a clear, bulleted list
  5. If new information conflicts with existing memory, keep the most recent version

  Remember: Only include factual information directly stated by the user. Do not make assumptions or inferences.

  Based on the chat history below, please update the user information:"""

  # Get the user ID from the config
  user_id = config["configurable"]["user_id"]

  # Retrieve existing memory from the store
  namespace = ("memory", user_id)
  existing_memory = store.get(namespace, "user_memory")
      
  # Extract the memory
  if existing_memory:
      existing_memory_content = existing_memory.value.get('memory')
  else:
      existing_memory_content = "No existing memory found."

  # Format the memory in the system prompt
  system_msg = CREATE_MEMORY_INSTRUCTION.format(memory=existing_memory_content)
  new_memory = llm.invoke([SystemMessage(content=system_msg)]+state['messages'])

  # Overwrite the existing memory in the store 
  key = "user_memory"

  # Write value as a dictionary with a memory key
  store.put(namespace, key, {"memory": new_memory.content})


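# Conditional edge: ask the LLM, via structured output, whether this turn
# needs document retrieval at all.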
def retrieve_or_not(state) -> bool:

  print("----RETRIEVE OR NOT----")
  messages = state["messages"]
  previous_summary = state['summary']
  last_messages=state["history"][-4:]
  user_reply = messages[-1].content

  sys_msg = '''You are a specialized decision-making system that evaluates whether retrieval is needed for the current conversation.

    Your only task is to determine if external information retrieval is necessary based on:
    1. The user's most recent message
    2. Recent conversation turns (if provided)
    3. Conversation summary (if provided)

    **You MUST respond ONLY with "yes" or "no".**

    Guidelines for your decision:
    - Reply "yes" if:
      - The query requests specific factual information (dates, statistics, events, etc.)
      - The query asks about real-world entities, events, or concepts that require precise information
      - The query references documents or data that would need to be retrieved
      - The query asks about recent or current events that may not be in your training data
      - The query explicitly asks for citations, references, or sources
      - The query contains specific dates, locations, or proper nouns that might require additional context
      - The query appears to be searching for specific information

    - Reply "no" if:
      - The query is a clarification about something previously explained
      - The query asks for creative content (stories, poems, etc.)
      - The query asks for general advice or opinions
      - The query is conversational in nature (greetings, thanks, etc.)
      - The query is about general concepts that don't require specific factual information
      - The query can be sufficiently answered with your existing knowledge
      - The query is a simple follow-up that doesn't introduce new topics requiring retrieval

    Remember: Your response must be EXACTLY "yes" or "no" with no additional text, explanation, or punctuation.
    Previous summary: {previous_summary}
    Last 4 messages: {last_messages}
    The user sent the reply as {user_reply}. 
    Should we need retrieval: '''

  class Decide(BaseModel):
    '''Decision for retrieval - yes or no'''
    decision: str = Field(description="Retrieval decision, 'yes' or 'no'")

  structured_llm = llm.with_structured_output(Decide)

  retrieve_decision = structured_llm.invoke([SystemMessage(content=sys_msg.format(previous_summary=previous_summary, last_messages=last_messages, user_reply=user_reply))])

  print('----RETRIEVE DECISION----', retrieve_decision)
  # structured output returns a Decide instance, so compare its .decision field
  # (the original compared the whole object to the string 'yes', which was always False)
  return retrieve_decision.decision.strip().lower() == 'yes'

def pass_node(state):
    # Record the current history length; the actual branch decision is made
    # by retrieve_or_not on the conditional edge out of this node.
    new_length = len(state["history"])
    return {"length": new_length}


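# Fetches the top-k matching documents for the rewritten query.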
def retrieve_docs(state):
  query = state['query']
  retrieved_docs = loaded_db.similarity_search(query=query, k=5)
  return {'context': retrieved_docs}





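# Graph wiring:
#   START -> (summary_fn on every 4th turn) -> retrieve_bool
#   retrieve_bool -> rewrite_query -> retriever -> write_memory   if retrieval is needed
#   retrieve_bool -> write_memory                                 otherwise
#   write_memory -> assistant -> END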
workflow = StateGraph(AgentState)


workflow.add_node('assistant', agent)
workflow.add_node('summary_fn', summary_function)
workflow.add_node('retrieve_bool', pass_node)
workflow.add_node('rewrite_query', rewrite_query)
workflow.add_node('retriever', retrieve_docs)
workflow.add_node("write_memory", write_memory)


workflow.add_conditional_edges(START, summary_or_not, {True: 'summary_fn', False: 'retrieve_bool'})
workflow.add_edge('summary_fn', 'retrieve_bool')
workflow.add_conditional_edges('retrieve_bool', retrieve_or_not, {True: 'rewrite_query', False: 'write_memory'})
workflow.add_edge('rewrite_query', 'retriever')
workflow.add_edge('retriever', 'write_memory')
workflow.add_edge('write_memory', 'assistant')
workflow.add_edge('assistant', END)

within_thread_memory = MemorySaver()
across_thread_memory = InMemoryStore()
chat_graph = workflow.compile(checkpointer=within_thread_memory, store=across_thread_memory)
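
# Optional smoke test (a minimal sketch, not part of the app): call the
# compiled graph directly before wiring up Gradio. The thread/user ids below
# are arbitrary placeholders; uncomment to try it from the command line.
# if __name__ == "__main__":
#     smoke_config = {"configurable": {"thread_id": "smoke", "user_id": "smoke"}}
#     out = chat_graph.invoke({
#         'messages': [('user', 'What is Amazon SageMaker?')],
#         'history': [], 'summary': '', 'length': 0, 'context': []
#     }, smoke_config)
#     print(out['messages'][-1].content)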





import time

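# Gradio callback: runs one graph turn, then streams the reply word-by-word by
# yielding successive (chat history, textbox) updates.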
def main(conv_history, user_reply):
    # Fixed thread/user ids so the checkpointer persists state across calls
    config = {"configurable": {"thread_id": "1", "user_id": "1"}}
    
    # Get the current state
    state = chat_graph.get_state(config).values
    
    # Initialize state if it doesn't exist yet
    if not state:
        state = {
            'messages': [],
            'summary': '',
            'length': 0
        }
    
    # Process the new message using LangGraph
    chat_graph.invoke({
        'messages': [('user', user_reply)],
        'history': state['messages'],
        'summary': state['summary'],
        'length': state['length'],
        'context': []
    }, config)
    
    # Get the updated state after processing
    output = chat_graph.get_state(config)
    new_messages = output.values['messages']
    
    # Initialize conv_history if None
    if conv_history is None:
        conv_history = []
    
    # Get the latest bot message
    bot_message = new_messages[-1].content  # Get the last message
    
    # Stream the response (optional feature)
    streamed_message = ""
    for token in bot_message.split():
        streamed_message += f"{token} "
        yield conv_history + [(user_reply, streamed_message.strip())], " "
        time.sleep(0.05)
    
    # Add the final conversation pair to history
    conv_history.append((user_reply, bot_message))
    
    yield conv_history, " "


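# Minimal Gradio UI: a chatbot pane, a query textbox, and submit/clear buttons.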
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("<center><h1>FAQ Chatbot! 📂📄</h1><center>")
    gr.Markdown("""##### This AI chatbot🤖 can answer your FAQ questions about AWS SageMaker""")
    
    with gr.Column():
        chatbot = gr.Chatbot(label='ChatBot')
        user_reply = gr.Textbox(label='Enter your query', placeholder='Type here...')
        
        with gr.Row():
            submit_button = gr.Button("Submit")
            clear_btn = gr.ClearButton([user_reply, chatbot], value='Clear')
            
            # The Chatbot component itself holds the visible conversation
            # history, so it is passed straight into main() as conv_history.
            submit_button.click(
                main,
                inputs=[chatbot, user_reply],
                outputs=[chatbot, user_reply]
            )
            
            # Clear the visible chat. Note: this does not reset the LangGraph
            # checkpointer state for thread "1"; a full reset would need extra
            # handling (e.g. switching to a fresh thread_id).
            def reset_langgraph_state():
                return []

            clear_btn.click(reset_langgraph_state, inputs=[], outputs=[chatbot])

demo.launch()