# Spaces:
# Sleeping
# Sleeping
import json
import logging
import os
import re
from typing import Any, Text, Dict, List, Optional

import cohere
import markdown
from dotenv import load_dotenv
from rasa_sdk import Action, Tracker, FormValidationAction
from rasa_sdk.events import SlotSet, ActiveLoop
from rasa_sdk.executor import CollectingDispatcher
# Load variables from a .env file (if present) before the API key is read.
load_dotenv()
# Cohere API key taken from the environment; os.getenv returns None when unset.
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
# Module-level Cohere client used by the LLM-backed actions in this file.
co = cohere.Client(COHERE_API_KEY)
def format_user_profile(tracker):
    """Collect the user's profile slots into a plain dictionary.

    Single-valued slots that are unset (or otherwise falsy) are reported as
    the string "not specified"; the two list-valued slots fall back to an
    empty list.

    Args:
        tracker: Object exposing ``get_slot(name)``.

    Returns:
        A dict keyed by slot name, in a fixed order.
    """
    single_valued = (
        "dietary_preference",
        "health_goal",
        "age",
        "weight",
        "height",
        "gender",
        "activity_level",
    )
    multi_valued = ("restrictions", "allergies")

    summary = {
        slot_name: tracker.get_slot(slot_name) or "not specified"
        for slot_name in single_valued
    }
    for slot_name in multi_valued:
        summary[slot_name] = tracker.get_slot(slot_name) or []
    return summary
def format_plan_markdown(plan_dict: Dict[Text, Any]) -> Text:
    """Render a structured diet plan as Markdown text.

    Recognised keys are ``overview``, ``weekly_plan`` (a dict mapping day name
    to either a dict of meal type -> description or any other value) and
    ``recommendations``. Unknown keys are ignored.

    Args:
        plan_dict: The plan to render. Non-dict values are returned via
            ``str()`` unchanged.

    Returns:
        The Markdown text with surrounding whitespace stripped; an empty
        string when none of the recognised keys are present.
    """
    if not isinstance(plan_dict, dict):
        return str(plan_dict)

    # Real newlines are required here: an escaped "\\n" would put a literal
    # backslash-n into the message and collapse the Markdown onto one line.
    parts: List[Text] = []

    if "overview" in plan_dict:
        parts.append(f"## Diet Overview\n{plan_dict['overview']}\n\n")

    weekly_plan = plan_dict.get("weekly_plan")
    if isinstance(weekly_plan, dict):
        parts.append("## Weekly Meal Plan\n")
        for day, meals in weekly_plan.items():
            parts.append(f"### {day.capitalize()}\n")
            if isinstance(meals, dict):
                parts.extend(
                    f"- **{meal_type.capitalize()}:** {description}\n"
                    for meal_type, description in meals.items()
                )
            else:
                parts.append(f"- {meals}\n")
            # Blank line between days.
            parts.append("\n")

    if "recommendations" in plan_dict:
        parts.append(
            f"## General Recommendations\n{plan_dict['recommendations']}\n"
        )

    return "".join(parts).strip()
class ActionGenerateDietPlan(Action):
    """Ask the Cohere chat model for a weekly meal plan and store the result.

    The model is instructed to return the plan as a JSON object inside a
    fenced ``json`` code block. The parsed plan (or an error dict containing
    the raw model text) is written to the ``generated_diet_plan`` slot.
    """

    # Captures the JSON object inside the first fenced ```json block.
    _JSON_BLOCK_RE = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)

    def name(self) -> Text:
        return "action_generate_diet_plan"

    @staticmethod
    def _build_prompt(user_profile: Dict[Text, Any]) -> Text:
        """Build the meal-plan prompt from the user profile dict."""
        return f"""
Create a personalized weekly meal plan based on the following user information:
Dietary Preference: {user_profile['dietary_preference']}
Health Goal: {user_profile['health_goal']}
Age: {user_profile['age']}
Weight: {user_profile['weight']}
Height: {user_profile['height']}
Gender: {user_profile['gender']}
Activity Level: {user_profile['activity_level']}
Dietary Restrictions: {', '.join(user_profile['restrictions']) if user_profile['restrictions'] else 'None'}
Allergies: {', '.join(user_profile['allergies']) if user_profile['allergies'] else 'None'}
Generate a complete, well-structured meal plan.
IMPORTANT: Structure the core meal plan data as a JSON object within a Markdown code block like this:
```json
{{
"overview": "A brief summary of the diet approach...",
"weekly_plan": {{
"Monday": {{"breakfast": "...", "lunch": "...", "dinner": "...", "snacks": "..."}},
"Tuesday": {{"breakfast": "...", "lunch": "...", "dinner": "...", "snacks": "..."}},
"Wednesday": {{"breakfast": "...", "lunch": "...", "dinner": "...", "snacks": "..."}},
"Thursday": {{"breakfast": "...", "lunch": "...", "dinner": "...", "snacks": "..."}},
"Friday": {{"breakfast": "...", "lunch": "...", "dinner": "...", "snacks": "..."}},
"Saturday": {{"breakfast": "...", "lunch": "...", "dinner": "...", "snacks": "..."}},
"Sunday": {{"breakfast": "...", "lunch": "...", "dinner": "...", "snacks": "..."}}
}},
"recommendations": "General tips and advice..."
}}
```
Ensure the JSON is valid. Provide the JSON structure exactly as specified above, filling in the details.
You can add introductory or concluding text outside the JSON block if needed.
"""

    @classmethod
    def _parse_plan(cls, llm_response_text: Text):
        """Extract the plan from the model output.

        Returns:
            A ``(plan_dict, parsed_ok)`` pair. When no JSON block is found or
            it fails to decode, ``plan_dict`` is an error dict carrying the
            raw model text and ``parsed_ok`` is False.
        """
        logger = logging.getLogger(__name__)
        json_match = cls._JSON_BLOCK_RE.search(llm_response_text)
        if json_match is None:
            logger.warning("No JSON block found in LLM response.")
            return (
                {"error": "No structured plan found", "raw_text": llm_response_text},
                False,
            )

        json_string = json_match.group(1)
        try:
            plan_dict = json.loads(json_string)
        except json.JSONDecodeError as json_err:
            logger.warning("JSON Decode Error: %s", json_err)
            logger.debug("Invalid JSON string: %s", json_string)
            return (
                {"error": "Failed to parse plan structure", "raw_text": llm_response_text},
                False,
            )

        logger.info("Successfully parsed JSON plan from LLM.")
        return plan_dict, True

    async def run(self, dispatcher: CollectingDispatcher,
                  tracker: Tracker,
                  domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        """Generate the plan, send it to the user and set the result slots.

        Returns:
            ``SlotSet`` events for ``generated_diet_plan`` (the plan or error
            dict on success, ``None`` when the call raised) and
            ``user_profile_complete`` (always True).
        """
        logger = logging.getLogger(__name__)
        user_profile = format_user_profile(tracker)
        prompt = self._build_prompt(user_profile)

        try:
            response = co.chat(
                message=prompt,
                model="command-r",
                temperature=0.3,
            )
            llm_response_text = response.text or ""
            plan_dict, parsed_ok = self._parse_plan(llm_response_text)

            if parsed_ok:
                plan_text = format_plan_markdown(plan_dict)
            else:
                # The error dict has none of the keys the Markdown formatter
                # renders, so formatting it would send an empty plan. Show
                # the model's raw answer instead.
                plan_text = llm_response_text.strip()

            dispatcher.utter_message(
                text="Here is your personalized diet plan:\n\n" + plan_text
            )
            return [
                SlotSet("generated_diet_plan", plan_dict),
                SlotSet("user_profile_complete", True),
            ]
        except Exception as e:
            # Boundary handler: the user must always receive a reply.
            logger.exception("Error during LLM call or processing: %s", e)
            dispatcher.utter_message(
                text=f"I'm sorry, I couldn't generate a diet plan at the moment. Error: {str(e)}"
            )
            return [
                SlotSet("generated_diet_plan", None),
                SlotSet("user_profile_complete", True),
            ]
class ActionAnswerDietQuestion(Action):
    """Answer a free-form nutrition question with the Cohere generate API."""

    def name(self) -> Text:
        return "action_answer_diet_question"

    def run(self, dispatcher: CollectingDispatcher,
            tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        """Send the model's answer to the user's latest message.

        On any exception an apology containing the error text is sent
        instead. No events are returned in either case.
        """
        question = tracker.latest_message.get("text")
        llm_prompt = f"""
Answer the following nutrition or diet-related question clearly and accurately:
Question: {question}
Provide a helpful, evidence-based response with practical advice when applicable.
If it would be beneficial, include a few bullet points with key takeaways.
"""
        generation_options = {
            "prompt": llm_prompt,
            "model": "command",
            "max_tokens": 800,
            "temperature": 0.7,
        }
        try:
            completion = co.generate(**generation_options)
            dispatcher.utter_message(text=completion.generations[0].text)
        except Exception as err:
            dispatcher.utter_message(
                text=f"I'm sorry, I couldn't answer your question at the moment. Error: {str(err)}"
            )
        return []
class ActionUpdateDietPlan(Action):
    """Ask the Cohere generate API to revise the meal plan from user feedback."""

    def name(self) -> Text:
        return "action_update_diet_plan"

    @staticmethod
    def _describe_current_plan(current_plan: Any) -> Text:
        """Return a text form of the stored plan for inclusion in the prompt."""
        if isinstance(current_plan, dict):
            if "weekly_plan" in current_plan:
                return json.dumps(current_plan, ensure_ascii=False, indent=2)
            if current_plan.get("raw_text"):
                # Plan generation stored the unparsed model output.
                return str(current_plan["raw_text"])
        return "No previous plan available."

    def run(self, dispatcher: CollectingDispatcher,
            tracker: Tracker,
            domain: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        """Send an updated plan based on the user's latest message.

        The prompt includes the user profile, the plan currently held in the
        ``generated_diet_plan`` slot and the feedback text. On any exception
        an apology containing the error text is sent instead. No events are
        returned.
        """
        user_request = tracker.latest_message.get("text")
        user_profile = format_user_profile(tracker)
        # The model is asked to update a plan, so it has to see that plan.
        current_plan_text = self._describe_current_plan(
            tracker.get_slot("generated_diet_plan")
        )
        prompt = f"""
Update the previously generated meal plan based on the following user feedback:
User Profile:
Dietary Preference: {user_profile['dietary_preference']}
Health Goal: {user_profile['health_goal']}
Age: {user_profile['age']}
Weight: {user_profile['weight']}
Height: {user_profile['height']}
Gender: {user_profile['gender']}
Activity Level: {user_profile['activity_level']}
Dietary Restrictions: {', '.join(user_profile['restrictions']) if user_profile['restrictions'] else 'None'}
Allergies: {', '.join(user_profile['allergies']) if user_profile['allergies'] else 'None'}
Previously Generated Meal Plan:
{current_plan_text}
User Feedback: {user_request}
Generate an updated meal plan that addresses the user's feedback while still keeping their health goals and dietary needs in mind.
Format the response in Markdown with clear headers and bullet points.
"""
        try:
            response = co.generate(
                prompt=prompt,
                model="command",
                max_tokens=2000,
                temperature=0.7,
            )
            updated_plan = response.generations[0].text
            dispatcher.utter_message(text=updated_plan)
            return []
        except Exception as e:
            dispatcher.utter_message(
                text=f"I'm sorry, I couldn't update the diet plan at the moment. Error: {str(e)}"
            )
            return []
def parse_new_meals(text: str) -> Dict[str, str]:
    """Split a free-text meal description into per-meal entries.

    The text is scanned word by word. A word that equals ``breakfast``,
    ``lunch``, ``dinner`` or ``snacks`` (case-insensitive, ignoring
    surrounding ``:.,;``) starts a new section, and the following words up to
    the next keyword become that meal's description. Words before the first
    keyword are ignored, and a repeated keyword replaces its earlier text.

    Args:
        text: The user's description of the new meals.

    Returns:
        A dict with one entry per meal keyword found. If no keyword is found
        the whole text is returned under ``"description"``; if the text is
        empty, blank or not a string, ``{"description": "No details provided"}``.
    """
    # Guard None / non-string input, which would otherwise raise on .split().
    if not isinstance(text, str) or not text.strip():
        return {"description": "No details provided"}

    placeholder = "Not specified"
    keywords = ("breakfast", "lunch", "dinner", "snacks")
    meals = dict.fromkeys(keywords, placeholder)

    current_keyword = None
    current_words: List[str] = []
    for word in text.split():
        candidate = word.lower().strip(":.,;")
        if candidate in keywords:
            if current_keyword:
                meals[current_keyword] = " ".join(current_words)
            current_keyword = candidate
            current_words = []
        elif current_keyword:
            current_words.append(word)
    if current_keyword:
        meals[current_keyword] = " ".join(current_words)

    specified = {meal: desc for meal, desc in meals.items() if desc != placeholder}
    if specified:
        return specified
    return {"description": text.strip()}
class ValidateDietChangeForm(FormValidationAction):
    """Validate ``diet_change_form`` and apply the requested day change.

    Once ``new_meals_text`` is filled while the form is active, the meals for
    the day in ``change_day`` are replaced in the stored plan and the form is
    deactivated.
    """

    def name(self) -> Text:
        return "validate_diet_change_form"

    async def run(
        self,
        dispatcher: CollectingDispatcher,
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        """Validate form slots and submit action.

        Returns:
            The events from the base validation run, followed (when the form
            is being submitted) by the slot resets, the form deactivation
            and, on success, the updated ``generated_diet_plan``.
        """
        events = await super().run(dispatcher, tracker, domain)

        active_loop = tracker.active_loop
        new_meals_text = tracker.get_slot("new_meals_text")
        form_is_active = bool(active_loop) and active_loop.get("name") == "diet_change_form"
        if not form_is_active or new_meals_text is None:
            return events

        generated_plan = tracker.get_slot("generated_diet_plan")
        day_to_change = tracker.get_slot("change_day")
        close_and_reset = [
            ActiveLoop(None),
            SlotSet("new_meals_text", None),
            SlotSet("change_day", None),
        ]

        if not generated_plan or not isinstance(generated_plan, dict) or "weekly_plan" not in generated_plan:
            dispatcher.utter_message(response="utter_no_plan_to_change")
            return events + close_and_reset

        if not day_to_change:
            dispatcher.utter_message(text="Sorry, I don't know which day you want to change.")
            return events + [ActiveLoop(None), SlotSet("new_meals_text", None)]

        weekly_plan = generated_plan["weekly_plan"]
        # Check the structure before iterating it; a non-dict weekly_plan
        # would otherwise raise instead of producing this message.
        if not isinstance(weekly_plan, dict):
            dispatcher.utter_message(text="Sorry, there was an issue updating the plan structure.")
            return events + close_and_reset

        day_to_change_normalized = str(day_to_change).strip().capitalize()
        wanted = day_to_change_normalized.lower()
        # Case-insensitive lookup so "monday" matches a "Monday" key.
        actual_day_key = next(
            (key for key in weekly_plan if str(key).lower() == wanted),
            None,
        )
        if actual_day_key is None:
            dispatcher.utter_message(
                text=f"Sorry, I couldn't find '{day_to_change_normalized}' in the current plan."
            )
            return events + close_and_reset

        parsed_new_meals = parse_new_meals(new_meals_text)

        # Copy both levels that are modified: a shallow copy of the plan alone
        # would still share weekly_plan with the dict held in the tracker slot.
        updated_weekly_plan = dict(weekly_plan)
        updated_weekly_plan[actual_day_key] = parsed_new_meals
        updated_plan = dict(generated_plan)
        updated_plan["weekly_plan"] = updated_weekly_plan

        formatted_updated_plan = format_plan_markdown(updated_plan)
        dispatcher.utter_message(
            response="utter_confirm_diet_change",
            generated_diet_plan=formatted_updated_plan,
            change_day=actual_day_key,
        )
        return events + [
            SlotSet("generated_diet_plan", updated_plan),
            SlotSet("new_meals_text", None),
            SlotSet("change_day", None),
            ActiveLoop(None),
        ]
class ActionUpdateUserInfo(Action):
    """Update one user-profile slot from ``change_field`` / ``new_value``."""

    # Ordered (keywords, slot) pairs; the first entry with a keyword contained
    # in the requested field name wins.
    _FIELD_KEYWORDS = (
        (("diet", "preference"), "dietary_preference"),
        (("health", "goal"), "health_goal"),
        (("age",), "age"),
        (("weight",), "weight"),
        (("height",), "height"),
        (("gender",), "gender"),
        (("activity",), "activity_level"),
    )

    def name(self) -> Text:
        return "action_update_user_info"

    @classmethod
    def _resolve_slot(cls, field_name: Text) -> Optional[Text]:
        """Map a lower-cased field description to a slot name, or None."""
        for keywords, slot_name in cls._FIELD_KEYWORDS:
            if any(keyword in field_name for keyword in keywords):
                return slot_name
        return None

    async def run(
        self,
        dispatcher: CollectingDispatcher,
        tracker: Tracker,
        domain: Dict[Text, Any],
    ) -> List[Dict[Text, Any]]:
        """Apply the requested profile change or ask for the missing piece.

        Returns:
            On a recognised field: ``SlotSet`` events for the target slot and
            for clearing ``change_field`` and ``new_value``. On an
            unrecognised field: only ``change_field`` is cleared so the user
            can name it again. Otherwise an empty list.
        """
        change_field = tracker.get_slot("change_field")
        new_value = tracker.get_slot("new_value")

        if not change_field:
            dispatcher.utter_message(response="utter_ask_which_field_to_change")
            return []
        if not new_value:
            dispatcher.utter_message(response="utter_ask_new_value")
            return []

        target_slot = self._resolve_slot(str(change_field).lower())
        if target_slot is None:
            # Nothing would be updated, so do not confirm an update; ask for
            # the field again and keep the value the user already gave.
            dispatcher.utter_message(response="utter_ask_which_field_to_change")
            return [SlotSet("change_field", None)]

        dispatcher.utter_message(response="utter_confirm_update")
        return [
            SlotSet(target_slot, new_value),
            SlotSet("change_field", None),
            SlotSet("new_value", None),
        ]