mbudisic committed on
Commit 0351317 · 1 Parent(s): e21930a

Full graph works. Now frontend and finetuning

notebooks/transcript_rag.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
pstuts_rag/pstuts_rag/datastore.py CHANGED
@@ -231,7 +231,7 @@ class DatastoreManager:
             VectorStoreRetriever: The configured retriever
         """
         return self.vector_store.as_retriever(
-            search_kwargs={"k": n_context_docs}
+            search_kwargs={"k": int(n_context_docs)}
         )
 
     def is_ready(self) -> bool:
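
Context for the change above: `n_context_docs` typically comes from configuration, and values read from environment variables arrive as strings, while the retriever's `k` must be an integer. A minimal sketch of what the cast guards against (placeholder values, not repo code):

    # Sketch only: config values read from env vars are strings, so the
    # similarity-search limit must be coerced before reaching the retriever.
    n_context_docs = "3"                        # e.g. os.environ["N_CONTEXT_DOCS"]
    search_kwargs = {"k": int(n_context_docs)}  # -> {"k": 3}
    # retriever = vector_store.as_retriever(search_kwargs=search_kwargs)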
pstuts_rag/pstuts_rag/graph.py CHANGED
@@ -118,7 +118,7 @@ def create_agent(
 
 
 def create_tavily_node(
-    name: str = "AdobeHelp", config: Configuration = Configuration() ) -> Callable
+    name: str = "AdobeHelp", config: Configuration = Configuration() ) -> Callable:
     """Initialize tool, agent, and node for Tavily search of helpx.adobe.com.
 
     This function sets up a search agent that can query Adobe Photoshop help topics
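
The fix above restores the missing colon on the signature so the module parses again. A usage sketch under the assumption that `create_tavily_node` returns a LangGraph-compatible node callable (the caller-side wiring here is hypothetical):

    # Hypothetical wiring sketch; only the factory call reflects the signature above.
    from pstuts_rag.configuration import Configuration
    from pstuts_rag.graph import create_tavily_node

    adobe_help_node = create_tavily_node(name="AdobeHelp", config=Configuration())
    # graph_builder.add_node("AdobeHelp", adobe_help_node)  # assumed usage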
pstuts_rag/pstuts_rag/nodes.py CHANGED
@@ -1,21 +1,29 @@
 # nodes.py
+from enum import Enum
+from typing import Annotated, Any, Callable, Dict, Literal
 
+import asyncio
+import logging
+import operator
 from langchain_openai import ChatOpenAI
 from langgraph.graph import StateGraph, MessagesState, START, END
 from langgraph.types import Command
-
+from langchain_core.documents import Document
 from langchain_core.runnables import RunnableConfig
-from langchain_core.messages import AnyMessage, HumanMessage
+from langchain_core.messages import HumanMessage, AIMessage
 from langgraph.checkpoint.memory import InMemorySaver
-from pstuts_rag.prompts import NODE_PROMPTS
-from pydantic import BaseModel, Field
+from numpy import add
+from langchain_community.tools.tavily_search import TavilySearchResults
+from langchain_tavily import TavilyExtract
+
+from pydantic import BaseModel, Field, HttpUrl
 
 
 from pstuts_rag.utils import ChatAPISelector
 from pstuts_rag.configuration import Configuration
-
-from enum import Enum
-from typing import Any, Callable, Dict, Literal
+from pstuts_rag.datastore import DatastoreManager
+from pstuts_rag.prompts import NODE_PROMPTS
+from pstuts_rag.rag_for_transcripts import create_transcript_rag_chain
 
 
 class TutorialState(MessagesState):
@@ -23,37 +31,97 @@ class TutorialState(MessagesState):
 
     # next: str
     query: str
-    video_references: set[Any]
-    url_references: set[Any]
+    video_references: Annotated[list[Document], operator.add]
+    url_references: Annotated[list[Dict], operator.add]
+    loop_count: int
+
+
+datastore = DatastoreManager()
+datastore.add_completion_callback(lambda: logging.warning("Loading complete."))
 
 
 def research(state: TutorialState, config: RunnableConfig):
 
-    # retrieve the LLM
-    # configurable = Configuration.from_runnable_config(config)
-    # cls = ChatAPISelector.get(configurable.llm_api, ChatOpenAI)
-    # llm = cls(model=configurable.llm_tool_model)
+    configurable = Configuration.from_runnable_config(config)
+    cls = ChatAPISelector.get(configurable.llm_api, ChatOpenAI)
+    llm = cls(model=configurable.llm_tool_model, temperature=0)
 
-    # # format the prompt
-    # prompt = NODE_PROMPTS["research"]
+    history = [
+        msg.content
+        for msg in state["messages"]
+        if getattr(msg, "role", "") == "ai"
+    ]
 
-    # history = [
-    #     msg.content
-    #     for msg in state["messages"]
-    #     if getattr(msg, "role", "") == "ai"
-    # ]
+    prompt = NODE_PROMPTS["research"].format(
+        history=history, query=state["query"]
+    )
 
-    # prompt = prompt.format(history=history)
+    search_query = llm.invoke([HumanMessage(content=prompt)])
 
-    pass
+    return {
+        "messages": [search_query],
+        "loop_count": state.get("loop_count", 0) + 1,
+    }
 
 
-def search_help(state: TutorialState, config: RunnableConfig):
-    pass
+async def search_help(
+    state: TutorialState, config: RunnableConfig | None = None
+):
+
+    configurable = (
+        Configuration()
+        if not config
+        else Configuration.from_runnable_config(config)
+    )
+
+    cls = ChatAPISelector.get(configurable.llm_api, ChatOpenAI)
+    llm = cls(model=configurable.llm_tool_model, temperature=0)
+    prompt = NODE_PROMPTS["search_summary"]
+
+    adobe_help_search = TavilySearchResults(
+        max_results=2,
+        include_domains=["helpx.adobe.com"],
+        include_answer=True,
+        include_raw_content=True,
+        include_images=True,
+        response_format="content_and_artifact",  # Always returns artifacts
+    )
+    query = state["messages"][-1].content
+    results = await adobe_help_search.ainvoke(query)
 
+    urls = list(r["url"] for r in results)
+    tool = TavilyExtract(
+        extract_depth="basic",
+        include_images=False,
+    )
+
+    results = await tool.ainvoke({"urls": urls})
+
+    if "results" in results:
+        all_text = list(r["raw_content"] for r in results["results"])
+    else:
+        all_text = []
+
+    prompt = prompt.format(
+        query=query,
+        text="\n***\n".join(all_text),
+    )
+
+    url_summary = await llm.ainvoke([HumanMessage(content=prompt)])
+
+    return {"messages": [url_summary], "url_references": results["results"]}
 
-def search_rag(state: TutorialState, config: RunnableConfig):
-    pass
+
+async def search_rag(state: TutorialState, config: RunnableConfig):
+
+    chain = create_transcript_rag_chain(datastore, config)
+
+    response = await chain.ainvoke({"question": state["messages"][-1].content})
+
+    return {
+        "messages": [response],
+        "video_references": response.additional_kwargs["context"],
+    }
 
 
 def join(state: TutorialState, config: RunnableConfig):
@@ -71,6 +139,11 @@ class YesNoDecision(BaseModel):
     decision: Literal["yes", "no"] = Field(description="Yes or no decision.")
 
 
+class URLReference(BaseModel):
+    summary: str
+    url: HttpUrl
+
+
 def route_is_relevant(
     state: TutorialState, config: RunnableConfig
 ) -> Command[Literal["research", "write_answer"]]:
@@ -82,25 +155,103 @@ def route_is_relevant(
         YesNoDecision
     )
 
+    human_messages = [
+        msg.content
+        for msg in state["messages"]
+        if isinstance(msg, HumanMessage)
+    ]
+
+    if len(human_messages) > 0:
+        query = human_messages[-1]
+    else:
+        query = state["query"]
+
     # format the prompt
-    prompt = NODE_PROMPTS["relevance"].format(query=state["query"])
+    prompt = NODE_PROMPTS["relevance"].format(query=query)
 
     relevance = llm.invoke([HumanMessage(content=prompt)])
     where = "research" if relevance.decision == "yes" else "write_answer"
-    answer = f"Query is {'not' if relevance.decision == 'no' else ''} relevant to Photoshop."
+    answer = (
+        f"Query is {'not' if relevance.decision == 'no' else ''} "
+        "relevant to Photoshop."
+    )
     return Command(
-        update={"messages": {"role": "ai", "content": answer}},
+        update={"messages": [AIMessage(content=answer)], "query": query},
         goto=where,
     )
 
 
+class IsComplete(BaseModel):
+    decision: Literal["yes", "no"] = Field(description="Yes or no decision.")
+    new_query: str = Field(description="Query for additional research.")
+
+
 def route_is_complete(
     state: TutorialState, config: RunnableConfig
-) -> Literal["yes", "no"]:
-    if True:
-        return "yes"
-    else:
-        return "no"
+) -> Command[Literal["research", "write_answer"]]:
+
+    # retrieve the LLM
+    configurable = Configuration.from_runnable_config(config)
+
+    if state["loop_count"] >= int(configurable.max_research_loops):
+        return Command(
+            update={
+                "messages": [
+                    AIMessage(
+                        content="Research loop count is too large. Do your best with what you have."
+                    )
+                ]
+            },
+            goto="write_answer",
+        )
+
+    cls = ChatAPISelector.get(configurable.llm_api, ChatOpenAI)
+    llm = cls(model=configurable.llm_tool_model).with_structured_output(
+        YesNoDecision
+    )
+
+    ai_messages = list(
+        msg.content for msg in state["messages"] if isinstance(msg, AIMessage)
+    )
+
+    # format the prompt
+    prompt = NODE_PROMPTS["completeness"].format(
+        query=state["query"], responses="\n\n".join(ai_messages)
+    )
+
+    completeness = llm.invoke([HumanMessage(content=prompt)])
+    where = "write_answer" if "yes" in completeness.decision else "research"
+
+    # Convert YesNoDecision to AIMessage
+    decision_message = AIMessage(
+        content=f"Research completeness: {completeness.decision}"
+    )
+
+    return Command(
+        update={"messages": [decision_message]},
+        goto=where,
+    )
+
+
+def write_answer(state: TutorialState, config: RunnableConfig):
+
+    # retrieve the LLM
+    configurable = Configuration.from_runnable_config(config)
+    cls = ChatAPISelector.get(configurable.llm_api, ChatOpenAI)
+    llm = cls(model=configurable.llm_tool_model)
+
+    ai_messages = list(
+        msg.content for msg in state["messages"] if isinstance(msg, AIMessage)
+    )
+
+    # format the prompt
+    prompt = NODE_PROMPTS["final_answer"].format(
+        query=state["query"], responses="\n\n".join(ai_messages)
+    )
+
+    final_answer = llm.invoke([HumanMessage(content=prompt)])
+
+    return {"messages": [final_answer]}
 
 
 graph_builder = StateGraph(TutorialState)
@@ -119,17 +270,16 @@ graph_builder.add_node(write_answer)
 #     {"yes": research.__name__, "no": write_answer.__name__},
 # )
 graph_builder.add_node(route_is_relevant)
+graph_builder.add_node(route_is_complete, defer=True)
+
 graph_builder.add_edge(START, route_is_relevant.__name__)
 graph_builder.add_edge(research.__name__, search_help.__name__)
 graph_builder.add_edge(research.__name__, search_rag.__name__)
-graph_builder.add_edge(search_help.__name__, join.__name__)
-graph_builder.add_edge(search_rag.__name__, join.__name__)
-graph_builder.add_conditional_edges(
-    join.__name__,
-    route_is_complete,
-    {"no": research.__name__, "yes": write_answer.__name__},
-)
+graph_builder.add_edge(search_help.__name__, route_is_complete.__name__)
+graph_builder.add_edge(search_rag.__name__, route_is_complete.__name__)
+
 graph_builder.add_edge(write_answer.__name__, END)
 
 
 graph = graph_builder.compile()
+asyncio.run(datastore.from_json_globs(Configuration().transcript_glob))
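
Taken together, the new nodes form a bounded research loop: `route_is_relevant` gates the incoming query, `research` fans out to `search_help` and `search_rag`, and `route_is_complete` (registered with `defer=True` so it waits for both branches) either loops back to `research` or hands off to `write_answer`, capped by `max_research_loops`. A minimal invocation sketch, assuming the module is importable as `pstuts_rag.nodes` and that the async nodes require `ainvoke`:

    # Sketch: drive the compiled graph once; state keys follow TutorialState above.
    import asyncio
    from pstuts_rag.nodes import graph

    async def ask(question: str) -> str:
        state = await graph.ainvoke(
            {
                "messages": [],        # filled in by the nodes
                "query": question,     # route_is_relevant falls back to this field
                "video_references": [],
                "url_references": [],
                "loop_count": 0,
            }
        )
        return state["messages"][-1].content  # final answer from write_answer

    print(asyncio.run(ask("How do I use adjustment layers in Photoshop?")))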
pstuts_rag/pstuts_rag/prompts.py CHANGED
@@ -164,3 +164,67 @@ is relevant to Adobe Photoshop, otherwise no.
 
 Relevant?
 """
+
+NODE_PROMPTS[
+    "search_summary"
+] = """
+<QUERY>
+{query}
+</QUERY>
+<WEBSITE_TEXT>
+{text}
+</WEBSITE_TEXT>
+
+<TASK>
+Use WEBSITE_TEXT to produce a summarized
+answer to the QUERY.
+
+Aim for the audience at a level of an advanced high school student.
+Do not invent material that is not in the text.
+
+Your output should be at most 200 words long.
+</TASK>
+"""
+
+NODE_PROMPTS[
+    "completeness"
+] = """
+<QUERY>
+{query}
+</QUERY>
+<RESEARCH>
+{responses}
+</RESEARCH>
+
+<TASK>
+Your goal is to evaluate if RESEARCH is sufficiently detailed to provide a comprehensive
+and clear answer for QUERY.
+
+If the RESEARCH is sufficiently complete, state "yes" as your decision.
+
+If new terms were introduced in RESEARCH that are not sufficiently explained,
+or the QUERY is not sufficiently addressed, respond with "no".
+</TASK>
+
+<FINAL_CHECK>
+Your response must be either "yes" or "no".
+</FINAL_CHECK>
+"""
+
+NODE_PROMPTS[
+    "final_answer"
+] = """
+<QUERY>
+{query}
+</QUERY>
+<RESEARCH>
+{responses}
+</RESEARCH>
+
+<TASK>
+Use the content in RESEARCH to provide a detailed answer to the QUERY.
+Do not add new material; ground yourself fully in the research context.
+
+End your response with "I hope you're happy!".
+</TASK>
+"""
pstuts_rag/pstuts_rag/rag_for_transcripts.py CHANGED
@@ -58,10 +58,10 @@ def post_process_response(
         else answer.content
     )
     # Only append references if the model provided a substantive answer
-    if "I don't know" not in answer.content:
-        text_w_references = "\n".join(
-            [str(text_w_references), "**REFERENCES**", references]
-        )
+    # if "I don't know" not in answer.content:
+    #     text_w_references = "\n".join(
+    #         [str(text_w_references), "**REFERENCES**", references]
+    #     )
 
     # Create new message with references and preserve original context metadata
     output: AIMessage = answer.model_copy(
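
With the inline reference block commented out, the answer text stays clean and callers are expected to pull sources from the message metadata instead, which is what `search_rag` does via `additional_kwargs["context"]`. A consuming-side sketch (the shape of the context entries and the direct chain construction here are assumptions):

    # Sketch: read references from metadata rather than an appended text block.
    import asyncio
    from pstuts_rag.datastore import DatastoreManager
    from pstuts_rag.rag_for_transcripts import create_transcript_rag_chain

    async def main():
        datastore = DatastoreManager()                         # assumes data already loaded
        chain = create_transcript_rag_chain(datastore, None)   # config passthrough assumed
        response = await chain.ainvoke({"question": "How do I crop to a fixed ratio?"})
        for doc in response.additional_kwargs.get("context", []):
            print(getattr(doc, "metadata", {}).get("source"))

    asyncio.run(main())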
pyproject.toml CHANGED
@@ -49,6 +49,7 @@ dependencies = [
     "langchain-ollama>=0.3.2",
     "simsimd>=6.2.1",
     "langgraph-cli[inmem]>=0.1.55",
+    "langchain-tavily>=0.2.0",
 ]
 authors = [{ name = "Marko Budisic", email = "[email protected]" }]
 license = "MIT"
temp_function.txt ADDED
@@ -0,0 +1,6 @@
+def enter_chain(message: str):
+    results = {
+        "messages": [HumanMessage(content=message)],
+        "team_members": ["VideoArchiveSearch", "AdobeHelp"],
+    }
+    return results
uv.lock CHANGED
@@ -1879,6 +1879,22 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/68/01/22dad84373ba282237a3351547443c9c94c39fe75f71a1759f97cfa89725/langchain_qdrant-0.2.0-py3-none-any.whl", hash = "sha256:8eab5b8a553204ddb809d8183a6f1bc12fc265688592d9d897388f6939c79bf8", size = 23406 },
 ]
 
+[[package]]
+name = "langchain-tavily"
+version = "0.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "aiohttp" },
+    { name = "langchain" },
+    { name = "langchain-core" },
+    { name = "mypy" },
+    { name = "requests" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/df/63/e7c41f837914806b3c255c4c46d0948528101279656a523b7e11be740e06/langchain_tavily-0.2.0.tar.gz", hash = "sha256:b400525d6d2c28902d2acb25af28751aa1a9a1f99c7880eea4d701f3993736fb", size = 19813 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b5/a7/2e59086df6006ac09a8d8d8f43683ff2f84608d69984bf1593c92faeefb0/langchain_tavily-0.2.0-py3-none-any.whl", hash = "sha256:a5b780f96c80d5a3e7c933da2d603cb26ba94b10f7c1ac4b89ce5b123c7541b4", size = 23580 },
+]
+
 [[package]]
 name = "langchain-text-splitters"
 version = "0.3.8"
@@ -3741,6 +3757,7 @@ dependencies = [
     { name = "langchain-ollama" },
     { name = "langchain-openai" },
     { name = "langchain-qdrant" },
+    { name = "langchain-tavily" },
     { name = "langgraph" },
     { name = "langgraph-cli", extra = ["inmem"] },
     { name = "langsmith" },
@@ -3810,6 +3827,7 @@ requires-dist = [
     { name = "langchain-ollama", specifier = ">=0.3.2" },
     { name = "langchain-openai" },
     { name = "langchain-qdrant", specifier = ">=0.2.0" },
+    { name = "langchain-tavily", specifier = ">=0.2.0" },
    { name = "langgraph", specifier = ">=0.2.55" },
     { name = "langgraph-cli", extras = ["inmem"], specifier = ">=0.1.55" },
     { name = "langsmith", specifier = ">=0.0.50" },