refactor: Enhance routing logic in tutorial state management
- Updated the `next` variable in `TutorialState` to clarify its purpose as a routing variable for the next node.
- Refactored the `route_is_relevant` and `route_is_complete` functions to return a dictionary instead of a `Command` object, making state updates consistent across nodes.
- Enhanced the decision-making logic in `search_help` to include a broader range of affirmative responses.
- Added conditional edges in the graph builder to route based on the `next` state, improving workflow orchestration (a minimal sketch of this pattern follows below).
Files changed: `pstuts_rag/pstuts_rag/nodes.py` (+55, -50)
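For orientation, here is a minimal, self-contained sketch of the routing pattern these changes adopt: router nodes return plain dict state updates that set a `next` field, and `add_conditional_edges` reads that field to pick the following node. This is not the project code; node and field names mirror the diff, and the stub relevance check stands in for the LLM call in `nodes.py`.

```python
# Minimal sketch (not the project code): dict-returning router nodes plus
# conditional edges keyed on state["next"]. Assumes langgraph is installed;
# the stub relevance check replaces the LLM call used in nodes.py.
import operator
from typing import Annotated

from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END


class State(TypedDict):
    query: str
    next: str  # routing variable read by the conditional edges
    messages: Annotated[list[str], operator.add]


def route_is_relevant(state: State) -> dict:
    # Stand-in for the LLM relevance check: only set the routing field.
    relevant = "photoshop" in state["query"].lower()
    return {"next": "research" if relevant else "write_answer"}


def research(state: State) -> dict:
    return {"messages": ["...research notes..."]}


def write_answer(state: State) -> dict:
    return {"messages": ["final answer"]}


builder = StateGraph(State)
builder.add_node(route_is_relevant)
builder.add_node(research)
builder.add_node(write_answer)
builder.add_edge(START, "route_is_relevant")
# The router only writes state["next"]; this edge consumes it.
builder.add_conditional_edges(
    "route_is_relevant",
    lambda state: state["next"],
    {"research": "research", "write_answer": "write_answer"},
)
builder.add_edge("research", "write_answer")
builder.add_edge("write_answer", END)

graph = builder.compile()
print(graph.invoke({"query": "How do I mask layers in Photoshop?", "messages": []}))
```

Returning dicts keeps routing declared on the graph itself via conditional edges, rather than inside each node's return value, which is the consistency gain the second bullet describes.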
pstuts_rag/pstuts_rag/nodes.py CHANGED

```diff
@@ -65,7 +65,7 @@ class YesNoAsk(Enum):
 class TutorialState(MessagesState):
     """State management for tutorial team workflow orchestration."""
 
-    # next
+    next: str  # Routing variable for next node
     query: str
     video_references: Annotated[list[Document], operator.add]
     url_references: Annotated[list[Dict], operator.add]
@@ -196,8 +196,24 @@ async def search_help(state: TutorialState, config: RunnableConfig):
     )
 
     logging.info(f"Permission response '{response}'")
-    decision =
-
+    decision = (
+        YesNoAsk.YES
+        if any(
+            affirmative in response.strip().lower()
+            for affirmative in [
+                "yes",
+                "true",
+                "ok",
+                "okay",
+                "1",
+                "y",
+                "sure",
+                "fine",
+                "alright",
+            ]
+        )
+        else YesNoAsk.NO
+    )
 
     response = {
         "messages": [],
@@ -309,7 +325,7 @@ class URLReference(BaseModel):
 def route_is_relevant(
     state: TutorialState,
     config: RunnableConfig,
-)
+):
     """Route based on whether the query is relevant to Photoshop tutorials.
 
     Args:
@@ -317,10 +333,8 @@ def route_is_relevant(
         config: RunnableConfig for accessing configuration parameters
 
     Returns:
-
+        dict: State update with 'next' set to the next node name
     """
-
-    # retrieve the LLM
     configurable = Configuration.from_runnable_config(config)
     cls = get_chat_api(configurable.llm_api)
     logging.info("LLM SELECTED: %s", cls)
@@ -340,19 +354,18 @@ def route_is_relevant(
     else:
         query = state["query"]
 
-    # format the prompt
     prompt = NODE_PROMPTS["relevance"].format(query=query)
-
     relevance = llm.invoke([HumanMessage(content=prompt)])
-
+    next_node = "research" if relevance.decision == "yes" else "write_answer"
     answer = (
         f"Query is {'not' if relevance.decision == 'no' else ''} "
         "relevant to Photoshop."
     )
-    return
-
-
-
+    return {
+        "messages": [AIMessage(content=answer)],
+        "query": query,
+        "next": next_node,
+    }
 
 
 class IsComplete(BaseModel):
@@ -367,9 +380,7 @@ class IsComplete(BaseModel):
     new_query: str = Field(description="Query for additional research.")
 
 
-def route_is_complete(
-    state: TutorialState, config: RunnableConfig
-) -> Command[Literal["research", "write_answer"]]:
+def route_is_complete(state: TutorialState, config: RunnableConfig):
    """Route based on whether research is complete or more is needed.
 
     Args:
@@ -377,23 +388,19 @@ def route_is_complete(
         config: RunnableConfig for accessing configuration parameters
 
     Returns:
-
+        dict: State update with 'next' set to the next node name
     """
-
-    # retrieve the LLM
     configurable = Configuration.from_runnable_config(config)
 
     if state["loop_count"] >= int(configurable.max_research_loops):
-        return
-
-
-
-
-
-
-
-            goto="write_answer",
-        )
+        return {
+            "messages": [
+                AIMessage(
+                    content="Research loop count is too large. Do your best with what you have."
+                )
+            ],
+            "next": "write_answer",
+        }
 
     cls = get_chat_api(configurable.llm_api)
     logging.info("LLM SELECTED: %s", cls)
@@ -406,23 +413,19 @@ def route_is_complete(
         msg.content for msg in state["messages"] if isinstance(msg, AIMessage)
     )
 
-    # format the prompt
     prompt = NODE_PROMPTS["completeness"].format(
-        query=state["query"],
+        query=state["query"],
+        responses="\n\n".join(str(m) for m in ai_messages),
     )
 
     completeness = llm.invoke([HumanMessage(content=prompt)])
-
-
-
+    next_node = (
+        "write_answer" if "yes" in completeness.decision else "research"
+    )
     decision_message = AIMessage(
         content=f"Research completeness: {completeness.decision}"
     )
-
-    return Command(
-        update={"messages": [decision_message]},
-        goto=where,
-    )
+    return {"messages": [decision_message], "next": next_node}
 
 
 def write_answer(state: TutorialState, config: RunnableConfig):
@@ -511,8 +514,6 @@ def initialize(
 
     graph_builder = StateGraph(TutorialState)
 
-    # graph_builder.add_node(route_is_relevant)
-    # graph_builder.add_node(route_is_complete, defer=True)
     graph_builder.add_node(init_state)
     graph_builder.add_node(research)
     graph_builder.add_node(search_help)
@@ -520,24 +521,28 @@ def initialize(
         "search_rag", functools.partial(search_rag, datastore=datastore)
     )
     graph_builder.add_node(write_answer)
-
-    # graph_builder.add_conditional_edges(
-    #     START,
-    #     route_is_relevant,
-    #     {"yes": research.__name__, "no": write_answer.__name__},
-    # )
     graph_builder.add_node(route_is_relevant)
     graph_builder.add_node(route_is_complete, defer=True)
     graph_builder.add_edge(START, init_state.__name__)
-
     graph_builder.add_edge(init_state.__name__, route_is_relevant.__name__)
     graph_builder.add_edge(research.__name__, search_help.__name__)
     graph_builder.add_edge(research.__name__, search_rag.__name__)
     graph_builder.add_edge(search_help.__name__, route_is_complete.__name__)
     graph_builder.add_edge(search_rag.__name__, route_is_complete.__name__)
-
     graph_builder.add_edge(write_answer.__name__, END)
 
+    # Conditional edges for routing based on 'next' in state
+    graph_builder.add_conditional_edges(
+        route_is_relevant.__name__,
+        lambda state: state["next"],
+        {"research": research.__name__, "write_answer": write_answer.__name__},
+    )
+    graph_builder.add_conditional_edges(
+        route_is_complete.__name__,
+        lambda state: state["next"],
+        {"research": research.__name__, "write_answer": write_answer.__name__},
+    )
+
     return datastore, graph_builder
```
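For contrast with the refactor, a hedged sketch of the pattern being removed: a `Command` return fuses the state update with the `goto` target inside the node, so no conditional edge is declared on the graph. The `where` variable and the `Literal` targets mirror the deleted lines; the loop-count threshold here is a made-up stand-in for `configurable.max_research_loops`.

```python
# Pre-refactor style (sketch only): the node itself decides where to go by
# returning langgraph's Command, instead of setting state["next"].
from typing import Literal

from langgraph.types import Command


def route_is_complete(state: dict) -> Command[Literal["research", "write_answer"]]:
    # Hypothetical threshold standing in for configurable.max_research_loops.
    where = "write_answer" if state.get("loop_count", 0) >= 3 else "research"
    return Command(
        update={"messages": [f"Research completeness: {where}"]},
        goto=where,
    )
```

Either style works in langgraph; the dict-plus-conditional-edges form chosen here keeps every route visible in `initialize`, and `defer=True` on `route_is_complete` still lets it act as the fan-in point for `search_help` and `search_rag`.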