mbudisic committed on
Commit
119237b
·
1 Parent(s): 6abc2ec

feat: Refactor ApplicationState and initialize datastore

Browse files

- Refactored the ApplicationState class to streamline initialization and state management.
- Introduced a new datastore management system using DatastoreManager and MemorySaver.
- Updated the on_chat_start function to handle datastore initialization and graph compilation.
- Added a new chainlit.md file for developer onboarding and documentation.

Files changed (2) hide show
  1. app.py +135 -165
  2. chainlit.md +14 -0
app.py CHANGED
@@ -1,7 +1,4 @@
1
  from pstuts_rag.configuration import Configuration
2
- from pstuts_rag.datastore import fill_the_db
3
- from pstuts_rag.graph import build_the_graph
4
- from pstuts_rag.state import PsTutsTeamState
5
  import requests
6
  import asyncio
7
  import json
@@ -12,21 +9,21 @@ import re
12
 
13
  import chainlit as cl
14
  from dotenv import load_dotenv
 
15
  from langchain_core.documents import Document
16
  from langchain_core.language_models import BaseChatModel
17
  from langchain_core.runnables import Runnable
18
- from langchain_openai import ChatOpenAI
19
  from langchain_core.embeddings import Embeddings
20
- from langchain_huggingface import HuggingFaceEmbeddings
21
 
22
 
23
  from langchain_core.messages import HumanMessage, BaseMessage
24
- import langgraph.graph
25
-
26
 
27
- import pstuts_rag.datastore
28
- import pstuts_rag.rag
29
 
 
 
 
 
30
 
31
  import nest_asyncio
32
  from uuid import uuid4
@@ -80,22 +77,10 @@ class ApplicationState:
80
  pointsLoaded: Number of data points loaded into the database
81
  """
82
 
83
- embeddings: Embeddings = None
84
- docs: List[Document] = []
85
- qdrant_client = None
86
- vector_store = None
87
- datastore_manager = None
88
- rag = None
89
- llm: BaseChatModel = None
90
- rag_chain: Runnable = None
91
-
92
- ai_graph: Runnable = None
93
- ai_graph_sketch = None
94
-
95
- tasks: List[asyncio.Task] = []
96
-
97
- hasLoaded: asyncio.Event = asyncio.Event()
98
- pointsLoaded: int = 0
99
 
100
  def __init__(self) -> None:
101
  """
@@ -104,7 +89,7 @@ class ApplicationState:
104
  load_dotenv()
105
  set_api_key_if_not_present("OPENAI_API_KEY")
106
  set_api_key_if_not_present("TAVILY_API_KEY")
107
- os.environ["LANGCHAIN_TRACING_V2"] = "true"
108
  os.environ["LANGCHAIN_PROJECT"] = (
109
  f"AIE - MBUDISIC - HF - CERT - {unique_id}"
110
  )
@@ -112,40 +97,7 @@ class ApplicationState:
112
 
113
 
114
  # Initialize global application state
115
- app_state = ApplicationState()
116
- params = Configuration()
117
- ai_state = PsTutsTeamState(
118
- messages=[],
119
- team_members=[VIDEOARCHIVE, ADOBEHELP],
120
- next="START",
121
- )
122
-
123
-
124
- async def initialize():
125
-
126
- await fill_the_db(app_state)
127
- app_state.ai_graph, app_state.ai_graph_sketch = await build_the_graph(
128
- app_state
129
- )
130
-
131
-
132
- def enter_chain(message: str):
133
- """
134
- Entry point for the agent graph chain.
135
-
136
- Transforms a user message into the state format expected by the agent graph.
137
-
138
- Args:
139
- message: User's input message
140
-
141
- Returns:
142
- Dictionary with the message and team members information
143
- """
144
- results = {
145
- "messages": [HumanMessage(content=message)],
146
- "team_members": [VIDEOARCHIVE, ADOBEHELP],
147
- }
148
- return results
149
 
150
 
151
  @cl.on_chat_start
@@ -156,105 +108,123 @@ async def on_chat_start():
156
  Sets up the language model, vector database components, and spawns tasks
157
  for database population and graph building.
158
  """
159
- app_state.llm = ChatOpenAI(model=params.tool_calling_model, temperature=0)
160
- # Use LangChain's built-in HuggingFaceEmbeddings wrapper
161
- app_state.embeddings = HuggingFaceEmbeddings(
162
- model_name=params.embedding_model
163
- )
164
 
165
- app_state.rag = pstuts_rag.rag.RAGChainInstance(
166
- name="deployed",
167
- qdrant_client=app_state.qdrant_client,
168
- llm=app_state.llm,
169
- embeddings=app_state.embeddings,
170
  )
171
-
172
- app_state.tasks.append(asyncio.create_task(initialize()))
173
-
174
-
175
- def process_response(
176
- response_message: BaseMessage,
177
- ) -> Tuple[str, List[cl.Message]]:
178
- """
179
- Processes a response from the AI agents.
180
-
181
- Extracts the main text and video references from the response,
182
- and creates message elements for displaying video content.
183
-
184
- Args:
185
- response: Response object from the AI agent
186
-
187
- Returns:
188
- Tuple containing the text response and a list of message elements with video references
189
- """
190
- streamed_text = f"[_from: {response_message.name}_]\n"
191
- msg_references = []
192
-
193
- if response_message.name == VIDEOARCHIVE:
194
- text, references = pstuts_rag.rag.RAGChainFactory.unpack_references(
195
- str(response_message.content)
196
- )
197
- streamed_text += text
198
-
199
- if len(references) > 0:
200
- references = json.loads(references)
201
- print(references)
202
-
203
- for ref in references:
204
- msg_references.append(
205
- cl.Message(
206
- content=(
207
- f"Watch {ref['title']} from timestamp "
208
- f"{round(ref['start'] // 60)}m:{round(ref['start'] % 60)}s"
209
- ),
210
- elements=[
211
- cl.Video(
212
- name=ref["title"],
213
- url=f"{ref['source']}#t={ref['start']}",
214
- display="side",
215
- )
216
- ],
217
- )
218
  )
219
- else:
220
- streamed_text += str(response_message.content)
221
 
222
- # Find all URLs in the content
223
- urls = re.findall(
224
- r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[/\w\.-]*(?:\?[/\w\.-=&%]*)?",
225
- str(response_message.content),
 
226
  )
227
- print(urls)
228
- links = []
229
- # Create a list of unique URLs
230
- for idx, u in enumerate(list(set(urls))):
231
-
232
- url = "https://api.microlink.io"
233
- params = {
234
- "url": u,
235
- "screenshot": True,
236
- }
237
-
238
- payload = requests.get(url, params)
239
-
240
- if payload:
241
- print(f"Successful screenshot\n{payload.json()}")
242
- links.append(
243
- cl.Image(
244
- name=f"Website {idx} Preview: {u}",
245
- display="side", # Show in the sidebar
246
- url=payload.json()["data"]["screenshot"]["url"],
247
- )
248
- )
249
 
250
- print(links)
251
- msg_references.append(
252
- cl.Message(
253
- content="\n".join([l.url for l in links]), elements=links
 
254
  )
255
  )
256
 
257
- return streamed_text, msg_references
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
 
260
  @cl.on_message
@@ -268,21 +238,21 @@ async def main(user_cl_message: cl.Message):
268
  Args:
269
  message: User's input message
270
  """
271
- for s in app_state.ai_graph.stream(
272
- user_cl_message.content, {"recursion_limit": 20}
273
- ):
274
- if "__end__" not in s and "supervisor" not in s.keys():
275
- for [node_type, node_response] in s.items():
276
- print(f"Processing {node_type} messages")
277
- for node_message in node_response["messages"]:
278
- print(f"Message {node_message}")
279
- msg = cl.Message(content="")
280
- text, references = process_response(node_message)
281
- for token in [char for char in text]:
282
- await msg.stream_token(token)
283
- await msg.send()
284
- for m in references:
285
- await m.send()
286
 
287
 
288
  if __name__ == "__main__":
 
1
  from pstuts_rag.configuration import Configuration
 
 
 
2
  import requests
3
  import asyncio
4
  import json
 
9
 
10
  import chainlit as cl
11
  from dotenv import load_dotenv
12
+
13
  from langchain_core.documents import Document
14
  from langchain_core.language_models import BaseChatModel
15
  from langchain_core.runnables import Runnable
 
16
  from langchain_core.embeddings import Embeddings
17
+ from langgraph.checkpoint.memory import MemorySaver
18
 
19
 
20
  from langchain_core.messages import HumanMessage, BaseMessage
 
 
21
 
 
 
22
 
23
+ from pstuts_rag.configuration import Configuration
24
+ from pstuts_rag.datastore import DatastoreManager
25
+ from pstuts_rag.rag_for_transcripts import create_transcript_rag_chain
26
+ from pstuts_rag.nodes import initialize
27
 
28
  import nest_asyncio
29
  from uuid import uuid4
 
77
  pointsLoaded: Number of data points loaded into the database
78
  """
79
 
80
+ config: Configuration = Configuration()
81
+ compiled_graph = None
82
+ datastore: DatastoreManager = None
83
+ checkpointer = MemorySaver()
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def __init__(self) -> None:
86
  """
 
89
  load_dotenv()
90
  set_api_key_if_not_present("OPENAI_API_KEY")
91
  set_api_key_if_not_present("TAVILY_API_KEY")
92
+ # os.environ["LANGCHAIN_TRACING_V2"] = "true"
93
  os.environ["LANGCHAIN_PROJECT"] = (
94
  f"AIE - MBUDISIC - HF - CERT - {unique_id}"
95
  )
 
97
 
98
 
99
  # Initialize global application state
100
+ _app_state = ApplicationState()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
 
103
  @cl.on_chat_start
 
108
  Sets up the language model, vector database components, and spawns tasks
109
  for database population and graph building.
110
  """
111
+ global _app_state
 
 
 
 
112
 
113
+ # Initialize datastore using asyncio.to_thread to avoid blocking
114
+ initialize_datastore: bool = _app_state.datastore is None or (
115
+ isinstance(_app_state.datastore, DatastoreManager)
116
+ and _app_state.datastore.count_docs() == 0
 
117
  )
118
+ if initialize_datastore:
119
+ _app_state.datastore = await asyncio.to_thread(
120
+ lambda: DatastoreManager(
121
+ config=_app_state.config
122
+ ).add_completion_callback(
123
+ lambda: cl.run_sync(
124
+ cl.Message(content="Datastore loading completed.").send()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  )
126
+ )
127
+ )
128
 
129
+ # Initialize and compile graph synchronously (blocking as intended)
130
+ if _app_state.compiled_graph is None:
131
+ _app_state.datastore, graph_builder = initialize(_app_state.datastore)
132
+ _app_state.compiled_graph = graph_builder.compile(
133
+ checkpointer=_app_state.checkpointer
134
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+ # Start datastore population as background task (non-blocking)
137
+ if initialize_datastore:
138
+ asyncio.create_task(
139
+ _app_state.datastore.from_json_globs(
140
+ _app_state.config.transcript_glob
141
  )
142
  )
143
 
144
+
145
+ # def process_response(
146
+ # response_message: BaseMessage,
147
+ # ) -> Tuple[str, List[cl.Message]]:
148
+ # """
149
+ # Processes a response from the AI agents.
150
+
151
+ # Extracts the main text and video references from the response,
152
+ # and creates message elements for displaying video content.
153
+
154
+ # Args:
155
+ # response: Response object from the AI agent
156
+
157
+ # Returns:
158
+ # Tuple containing the text response and a list of message elements with video references
159
+ # """
160
+ # streamed_text = f"[_from: {response_message.name}_]\n"
161
+ # msg_references = []
162
+
163
+ # if response_message.name == VIDEOARCHIVE:
164
+ # text, references = pstuts_rag.rag.RAGChainFactory.unpack_references(
165
+ # str(response_message.content)
166
+ # )
167
+ # streamed_text += text
168
+
169
+ # if len(references) > 0:
170
+ # references = json.loads(references)
171
+ # print(references)
172
+
173
+ # for ref in references:
174
+ # msg_references.append(
175
+ # cl.Message(
176
+ # content=(
177
+ # f"Watch {ref['title']} from timestamp "
178
+ # f"{round(ref['start'] // 60)}m:{round(ref['start'] % 60)}s"
179
+ # ),
180
+ # elements=[
181
+ # cl.Video(
182
+ # name=ref["title"],
183
+ # url=f"{ref['source']}#t={ref['start']}",
184
+ # display="side",
185
+ # )
186
+ # ],
187
+ # )
188
+ # )
189
+ # else:
190
+ # streamed_text += str(response_message.content)
191
+
192
+ # # Find all URLs in the content
193
+ # urls = re.findall(
194
+ # r"https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+[/\w\.-]*(?:\?[/\w\.-=&%]*)?",
195
+ # str(response_message.content),
196
+ # )
197
+ # print(urls)
198
+ # links = []
199
+ # # Create a list of unique URLs
200
+ # for idx, u in enumerate(list(set(urls))):
201
+
202
+ # url = "https://api.microlink.io"
203
+ # params = {
204
+ # "url": u,
205
+ # "screenshot": True,
206
+ # }
207
+
208
+ # payload = requests.get(url, params)
209
+
210
+ # if payload:
211
+ # print(f"Successful screenshot\n{payload.json()}")
212
+ # links.append(
213
+ # cl.Image(
214
+ # name=f"Website {idx} Preview: {u}",
215
+ # display="side", # Show in the sidebar
216
+ # url=payload.json()["data"]["screenshot"]["url"],
217
+ # )
218
+ # )
219
+
220
+ # print(links)
221
+ # msg_references.append(
222
+ # cl.Message(
223
+ # content="\n".join([l.url for l in links]), elements=links
224
+ # )
225
+ # )
226
+
227
+ # return streamed_text, msg_references
228
 
229
 
230
  @cl.on_message
 
238
  Args:
239
  message: User's input message
240
  """
241
+ # for s in app_state.ai_graph.stream(
242
+ # user_cl_message.content, {"recursion_limit": 20}
243
+ # ):
244
+ # if "__end__" not in s and "supervisor" not in s.keys():
245
+ # for [node_type, node_response] in s.items():
246
+ # print(f"Processing {node_type} messages")
247
+ # for node_message in node_response["messages"]:
248
+ # print(f"Message {node_message}")
249
+ # msg = cl.Message(content="")
250
+ # text, references = process_response(node_message)
251
+ # for token in [char for char in text]:
252
+ # await msg.stream_token(token)
253
+ # await msg.send()
254
+ # for m in references:
255
+ # await m.send()
256
 
257
 
258
  if __name__ == "__main__":
chainlit.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Welcome to Chainlit! πŸš€πŸ€–
2
+
3
+ Hi there, Developer! πŸ‘‹ We're excited to have you on board. Chainlit is a powerful tool designed to help you prototype, debug and share applications built on top of LLMs.
4
+
5
+ ## Useful Links πŸ”—
6
+
7
+ - **Documentation:** Get started with our comprehensive [Chainlit Documentation](https://docs.chainlit.io) πŸ“š
8
+ - **Discord Community:** Join our friendly [Chainlit Discord](https://discord.gg/k73SQ3FyUh) to ask questions, share your projects, and connect with other developers! πŸ’¬
9
+
10
+ We can't wait to see what you create with Chainlit! Happy coding! πŸ’»πŸ˜Š
11
+
12
+ ## Welcome screen
13
+
14
+ To modify the welcome screen, edit the `chainlit.md` file at the root of your project. If you do not want a welcome screen, just leave this file empty.