mbudisic commited on
Commit
b30f052
·
1 Parent(s): e86ca95

refactor: Enhance routing logic in tutorial state management

Browse files

- Updated the `next` variable in `TutorialState` to clarify its purpose as a routing variable for the next node.
- Refactored the `search_help` and `route_is_complete` functions to return a dictionary instead of a Command object, improving consistency in state updates.
- Enhanced the decision-making logic in `search_help` to include a broader range of affirmative responses.
- Added conditional edges in the graph builder to route based on the `next` state, improving workflow orchestration.

Files changed (1) hide show
  1. pstuts_rag/pstuts_rag/nodes.py +55 -50
pstuts_rag/pstuts_rag/nodes.py CHANGED
@@ -65,7 +65,7 @@ class YesNoAsk(Enum):
65
  class TutorialState(MessagesState):
66
  """State management for tutorial team workflow orchestration."""
67
 
68
- # next: str
69
  query: str
70
  video_references: Annotated[list[Document], operator.add]
71
  url_references: Annotated[list[Dict], operator.add]
@@ -196,8 +196,24 @@ async def search_help(state: TutorialState, config: RunnableConfig):
196
  )
197
 
198
  logging.info(f"Permission response '{response}'")
199
- decision = YesNoAsk.YES if "yes" in response.strip() else YesNoAsk.NO
200
- return {"search_permission": decision}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  response = {
203
  "messages": [],
@@ -309,7 +325,7 @@ class URLReference(BaseModel):
309
  def route_is_relevant(
310
  state: TutorialState,
311
  config: RunnableConfig,
312
- ) -> Command[Literal["research", "write_answer"]]:
313
  """Route based on whether the query is relevant to Photoshop tutorials.
314
 
315
  Args:
@@ -317,10 +333,8 @@ def route_is_relevant(
317
  config: RunnableConfig for accessing configuration parameters
318
 
319
  Returns:
320
- Command: Navigation command to either 'research' or 'write_answer'
321
  """
322
-
323
- # retrieve the LLM
324
  configurable = Configuration.from_runnable_config(config)
325
  cls = get_chat_api(configurable.llm_api)
326
  logging.info("LLM SELECTED: %s", cls)
@@ -340,19 +354,18 @@ def route_is_relevant(
340
  else:
341
  query = state["query"]
342
 
343
- # format the prompt
344
  prompt = NODE_PROMPTS["relevance"].format(query=query)
345
-
346
  relevance = llm.invoke([HumanMessage(content=prompt)])
347
- where = "research" if relevance.decision == "yes" else "write_answer"
348
  answer = (
349
  f"Query is {'not' if relevance.decision == 'no' else ''} "
350
  "relevant to Photoshop."
351
  )
352
- return Command(
353
- update={"messages": [AIMessage(content=answer)], "query": query},
354
- goto=where,
355
- )
 
356
 
357
 
358
  class IsComplete(BaseModel):
@@ -367,9 +380,7 @@ class IsComplete(BaseModel):
367
  new_query: str = Field(description="Query for additional research.")
368
 
369
 
370
- def route_is_complete(
371
- state: TutorialState, config: RunnableConfig
372
- ) -> Command[Literal["research", "write_answer"]]:
373
  """Route based on whether research is complete or more is needed.
374
 
375
  Args:
@@ -377,23 +388,19 @@ def route_is_complete(
377
  config: RunnableConfig for accessing configuration parameters
378
 
379
  Returns:
380
- Command: Navigation command to either 'research' or 'write_answer'
381
  """
382
-
383
- # retrieve the LLM
384
  configurable = Configuration.from_runnable_config(config)
385
 
386
  if state["loop_count"] >= int(configurable.max_research_loops):
387
- return Command(
388
- update={
389
- "messages": [
390
- AIMessage(
391
- content="Research loop count is too large. Do your best with what you have."
392
- )
393
- ]
394
- },
395
- goto="write_answer",
396
- )
397
 
398
  cls = get_chat_api(configurable.llm_api)
399
  logging.info("LLM SELECTED: %s", cls)
@@ -406,23 +413,19 @@ def route_is_complete(
406
  msg.content for msg in state["messages"] if isinstance(msg, AIMessage)
407
  )
408
 
409
- # format the prompt
410
  prompt = NODE_PROMPTS["completeness"].format(
411
- query=state["query"], responses="\n\n".join(ai_messages)
 
412
  )
413
 
414
  completeness = llm.invoke([HumanMessage(content=prompt)])
415
- where = "write_answer" if "yes" in completeness.decision else "research"
416
-
417
- # Convert YesNoDecision to AIMessage
418
  decision_message = AIMessage(
419
  content=f"Research completeness: {completeness.decision}"
420
  )
421
-
422
- return Command(
423
- update={"messages": [decision_message]},
424
- goto=where,
425
- )
426
 
427
 
428
  def write_answer(state: TutorialState, config: RunnableConfig):
@@ -511,8 +514,6 @@ def initialize(
511
 
512
  graph_builder = StateGraph(TutorialState)
513
 
514
- # graph_builder.add_node(route_is_relevant)
515
- # graph_builder.add_node(route_is_complete, defer=True)
516
  graph_builder.add_node(init_state)
517
  graph_builder.add_node(research)
518
  graph_builder.add_node(search_help)
@@ -520,24 +521,28 @@ def initialize(
520
  "search_rag", functools.partial(search_rag, datastore=datastore)
521
  )
522
  graph_builder.add_node(write_answer)
523
-
524
- # graph_builder.add_conditional_edges(
525
- # START,
526
- # route_is_relevant,
527
- # {"yes": research.__name__, "no": write_answer.__name__},
528
- # )
529
  graph_builder.add_node(route_is_relevant)
530
  graph_builder.add_node(route_is_complete, defer=True)
531
  graph_builder.add_edge(START, init_state.__name__)
532
-
533
  graph_builder.add_edge(init_state.__name__, route_is_relevant.__name__)
534
  graph_builder.add_edge(research.__name__, search_help.__name__)
535
  graph_builder.add_edge(research.__name__, search_rag.__name__)
536
  graph_builder.add_edge(search_help.__name__, route_is_complete.__name__)
537
  graph_builder.add_edge(search_rag.__name__, route_is_complete.__name__)
538
-
539
  graph_builder.add_edge(write_answer.__name__, END)
540
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  return datastore, graph_builder
542
 
543
 
 
65
  class TutorialState(MessagesState):
66
  """State management for tutorial team workflow orchestration."""
67
 
68
+ next: str # Routing variable for next node
69
  query: str
70
  video_references: Annotated[list[Document], operator.add]
71
  url_references: Annotated[list[Dict], operator.add]
 
196
  )
197
 
198
  logging.info(f"Permission response '{response}'")
199
+ decision = (
200
+ YesNoAsk.YES
201
+ if any(
202
+ affirmative in response.strip().lower()
203
+ for affirmative in [
204
+ "yes",
205
+ "true",
206
+ "ok",
207
+ "okay",
208
+ "1",
209
+ "y",
210
+ "sure",
211
+ "fine",
212
+ "alright",
213
+ ]
214
+ )
215
+ else YesNoAsk.NO
216
+ )
217
 
218
  response = {
219
  "messages": [],
 
325
  def route_is_relevant(
326
  state: TutorialState,
327
  config: RunnableConfig,
328
+ ):
329
  """Route based on whether the query is relevant to Photoshop tutorials.
330
 
331
  Args:
 
333
  config: RunnableConfig for accessing configuration parameters
334
 
335
  Returns:
336
+ dict: State update with 'next' set to the next node name
337
  """
 
 
338
  configurable = Configuration.from_runnable_config(config)
339
  cls = get_chat_api(configurable.llm_api)
340
  logging.info("LLM SELECTED: %s", cls)
 
354
  else:
355
  query = state["query"]
356
 
 
357
  prompt = NODE_PROMPTS["relevance"].format(query=query)
 
358
  relevance = llm.invoke([HumanMessage(content=prompt)])
359
+ next_node = "research" if relevance.decision == "yes" else "write_answer"
360
  answer = (
361
  f"Query is {'not' if relevance.decision == 'no' else ''} "
362
  "relevant to Photoshop."
363
  )
364
+ return {
365
+ "messages": [AIMessage(content=answer)],
366
+ "query": query,
367
+ "next": next_node,
368
+ }
369
 
370
 
371
  class IsComplete(BaseModel):
 
380
  new_query: str = Field(description="Query for additional research.")
381
 
382
 
383
+ def route_is_complete(state: TutorialState, config: RunnableConfig):
 
 
384
  """Route based on whether research is complete or more is needed.
385
 
386
  Args:
 
388
  config: RunnableConfig for accessing configuration parameters
389
 
390
  Returns:
391
+ dict: State update with 'next' set to the next node name
392
  """
 
 
393
  configurable = Configuration.from_runnable_config(config)
394
 
395
  if state["loop_count"] >= int(configurable.max_research_loops):
396
+ return {
397
+ "messages": [
398
+ AIMessage(
399
+ content="Research loop count is too large. Do your best with what you have."
400
+ )
401
+ ],
402
+ "next": "write_answer",
403
+ }
 
 
404
 
405
  cls = get_chat_api(configurable.llm_api)
406
  logging.info("LLM SELECTED: %s", cls)
 
413
  msg.content for msg in state["messages"] if isinstance(msg, AIMessage)
414
  )
415
 
 
416
  prompt = NODE_PROMPTS["completeness"].format(
417
+ query=state["query"],
418
+ responses="\n\n".join(str(m) for m in ai_messages),
419
  )
420
 
421
  completeness = llm.invoke([HumanMessage(content=prompt)])
422
+ next_node = (
423
+ "write_answer" if "yes" in completeness.decision else "research"
424
+ )
425
  decision_message = AIMessage(
426
  content=f"Research completeness: {completeness.decision}"
427
  )
428
+ return {"messages": [decision_message], "next": next_node}
 
 
 
 
429
 
430
 
431
  def write_answer(state: TutorialState, config: RunnableConfig):
 
514
 
515
  graph_builder = StateGraph(TutorialState)
516
 
 
 
517
  graph_builder.add_node(init_state)
518
  graph_builder.add_node(research)
519
  graph_builder.add_node(search_help)
 
521
  "search_rag", functools.partial(search_rag, datastore=datastore)
522
  )
523
  graph_builder.add_node(write_answer)
 
 
 
 
 
 
524
  graph_builder.add_node(route_is_relevant)
525
  graph_builder.add_node(route_is_complete, defer=True)
526
  graph_builder.add_edge(START, init_state.__name__)
 
527
  graph_builder.add_edge(init_state.__name__, route_is_relevant.__name__)
528
  graph_builder.add_edge(research.__name__, search_help.__name__)
529
  graph_builder.add_edge(research.__name__, search_rag.__name__)
530
  graph_builder.add_edge(search_help.__name__, route_is_complete.__name__)
531
  graph_builder.add_edge(search_rag.__name__, route_is_complete.__name__)
 
532
  graph_builder.add_edge(write_answer.__name__, END)
533
 
534
+ # Conditional edges for routing based on 'next' in state
535
+ graph_builder.add_conditional_edges(
536
+ route_is_relevant.__name__,
537
+ lambda state: state["next"],
538
+ {"research": research.__name__, "write_answer": write_answer.__name__},
539
+ )
540
+ graph_builder.add_conditional_edges(
541
+ route_is_complete.__name__,
542
+ lambda state: state["next"],
543
+ {"research": research.__name__, "write_answer": write_answer.__name__},
544
+ )
545
+
546
  return datastore, graph_builder
547
 
548