# Spaces:
# Sleeping
# Sleeping
import logging
import os
import re
import time
import uuid
from threading import Timer

import cohere
import gradio as gr
from dotenv import load_dotenv
# Populate os.environ from a local .env file (when one exists) before the key is read.
load_dotenv()
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")

# Single Cohere client shared by every chat session.
co = cohere.Client(COHERE_API_KEY)

# Module-level placeholder; each session keeps its own plan in
# ConversationState.given_diet (this name is not read elsewhere in the visible source).
GIVEN_DIET = ""
# Stylesheet passed to gr.Blocks(css=...). The diet-card / diet-header / day-header /
# meal-type / diet-list / markdown-text classes are emitted by format_markdown; the
# app-header, app-title, input-container, features-container and feature-card classes
# are used by the UI layout. Selectors such as .message-bubble, .user-message,
# .bot-message, .center-content, .feature-icon and .footer are not referenced anywhere
# in the visible source (possibly leftovers - confirm before removing).
custom_css = """
.gradio-container {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}
.app-header {
text-align: center;
margin-bottom: 20px;
}
.app-title {
background: linear-gradient(90deg, #2e7d32, #1976d2);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 700;
font-size: 2.5em;
margin-bottom: 8px;
}
.markdown-text p {
margin-bottom: 12px;
line-height: 1.5;
}
.diet-header {
font-size: 1.5em;
margin-top: 24px;
margin-bottom: 16px;
font-weight: 600;
color: #2e7d32;
border-bottom: 1px solid rgba(0, 0, 0, 0.1);
padding-bottom: 6px;
}
.day-header {
font-size: 1.2em;
margin-top: 20px;
margin-bottom: 12px;
font-weight: 600;
color: #1976d2;
background-color: #f5f5f5;
padding: 8px 12px;
border-radius: 6px;
}
.meal-type {
font-weight: 600;
color: #d32f2f;
margin-right: 4px;
}
.markdown-text ul, .markdown-text ol {
margin-bottom: 12px;
padding-left: 24px;
}
.markdown-text li {
margin-bottom: 4px;
}
.markdown-text strong {
font-weight: 600;
}
.message-bubble {
padding: 16px;
border-radius: 12px;
max-width: 90%;
margin-bottom: 8px;
}
.user-message {
background-color: #e3f2fd;
align-self: flex-end;
}
.bot-message {
background-color: #f1f8e9;
align-self: flex-start;
}
.center-content {
display: flex;
flex-direction: column;
align-items: center;
}
.diet-card {
border: 1px solid #e0e0e0;
border-radius: 12px;
padding: 20px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.05);
background-color: white;
margin-top: 10px;
margin-bottom: 20px;
width: 100%;
line-height: 1.6;
}
.diet-card ul li::marker {
color: #2e7d32;
}
.diet-card p {
margin-bottom: 8px;
}
.diet-list {
list-style-type: none;
padding-left: 0;
}
.diet-list li {
background-color: #f9f9f9;
margin-bottom: 6px;
padding: 6px 10px;
border-radius: 4px;
border-left: 3px solid #2e7d32;
}
.input-container {
background-color: white;
border-radius: 10px;
padding: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.05);
}
.features-container {
display: flex;
flex-wrap: wrap;
margin-top: 24px;
gap: 16px;
}
.feature-card {
flex: 1 1 calc(33% - 16px);
min-width: 200px;
background-color: white;
padding: 16px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.05);
}
.feature-icon {
font-size: 24px;
margin-bottom: 10px;
color: #2e7d32;
}
.footer {
text-align: center;
margin-top: 30px;
font-size: 0.9em;
color: #666;
}
"""
# Line prefixes (Markdown "-" and "*" bullets) that start a diet-list entry.
_BULLET_PREFIXES = ("- ", "* ")


def _render_diet_list(items):
    """Render a run of bullet texts as a single ``<ul class="diet-list">`` element."""
    body = "".join(f" <li>{item}</li>\n" for item in items)
    return f'<ul class="diet-list">\n{body}</ul>'


def format_markdown(text):
    """Turn the model's Markdown meal plan into styled HTML for the chat window.

    Known section headers (``## Diet Overview``, ``## Weekly Meal Plan``,
    ``## General Recommendations``), weekday headers (``### Monday`` ...) and
    bold meal labels (``**Breakfast**:`` ...) are rewritten into elements that
    carry the CSS classes defined in ``custom_css``. Each run of consecutive
    bullet lines (``- item`` or ``* item``) becomes one ``<ul class="diet-list">``.

    Args:
        text: Markdown text, typically the raw LLM response.

    Returns:
        The converted text wrapped in ``<div class="diet-card markdown-text">``.
    """
    # str.replace is a no-op when a marker is absent, so every marker is
    # rewritten unconditionally; no presence check is needed.
    for section in ("Diet Overview", "Weekly Meal Plan", "General Recommendations"):
        text = text.replace(f"## {section}", f'<h2 class="diet-header">{section}</h2>')
    for day in ("Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"):
        text = text.replace(f"### {day}", f'<h3 class="day-header">{day}</h3>')
    for meal in ("Breakfast", "Lunch", "Dinner", "Snacks"):
        text = text.replace(f"**{meal}**:", f'<strong class="meal-type">{meal}</strong>:')

    formatted_lines = []
    list_items = []
    for line in text.split("\n"):
        stripped = line.strip()
        if stripped.startswith(_BULLET_PREFIXES):
            list_items.append(stripped[2:])
            continue
        # Any non-bullet line closes the currently open run of bullets.
        if list_items:
            formatted_lines.append(_render_diet_list(list_items))
            list_items = []
        formatted_lines.append(line)
    # Flush a list that runs to the very end of the text.
    if list_items:
        formatted_lines.append(_render_diet_list(list_items))

    body = "\n".join(formatted_lines)
    return f'<div class="diet-card markdown-text">{body}</div>'
class ConversationState:
    """Dialogue state for a single chat session.

    Attributes:
        current_step: name of the conversation step that handles the next message.
        user_profile: details gathered from the user so far; unknown scalar fields
            are None, while ``restrictions`` and ``allergies`` are lists.
        given_diet: the formatted meal plan last generated for this session ("" if none).
    """

    def __init__(self):
        # reset() assigns every attribute, given_diet included.
        self.reset()

    @staticmethod
    def _empty_profile():
        """Build a fresh profile dict (new lists on every call, so sessions never share them)."""
        return dict(
            dietary_preference=None,
            health_goal=None,
            age=None,
            weight=None,
            height=None,
            gender=None,
            activity_level=None,
            restrictions=[],
            allergies=[],
        )

    def reset(self):
        """Forget everything gathered so far and return to the greeting step."""
        self.current_step = "greeting"
        self.user_profile = self._empty_profile()
        self.given_diet = ""
# Per-session conversation state and last-activity timestamps, keyed by session id.
session_states = {}
session_last_activity = {}
# Seconds of inactivity after which a session is discarded.
SESSION_TIMEOUT = 360
# Seconds between background sweeps for idle sessions.
CLEANUP_INTERVAL = 300


def _schedule_cleanup():
    """Arm a one-shot timer that runs ``cleanup_sessions`` after ``CLEANUP_INTERVAL`` seconds.

    The timer thread is a daemon, so a pending sweep never keeps the process
    alive after the app shuts down.
    """
    timer = Timer(CLEANUP_INTERVAL, cleanup_sessions)
    timer.daemon = True
    timer.start()
    return timer


def cleanup_sessions():
    """Drop every session idle longer than ``SESSION_TIMEOUT`` seconds, then re-arm the sweep.

    Request handlers may add sessions (via get_session_state) while this runs on
    the timer thread, so the timestamps are copied before iterating. Re-arming
    happens in ``finally`` so that one failed sweep cannot stop all future sweeps.
    """
    try:
        now = time.time()
        for session_id, last_activity in session_last_activity.copy().items():
            if now - last_activity > SESSION_TIMEOUT:
                session_states.pop(session_id, None)
                session_last_activity.pop(session_id, None)
    finally:
        _schedule_cleanup()


_schedule_cleanup()
def get_session_state(session_id):
    """Return the ConversationState for ``session_id``, creating it on first use.

    Every call refreshes the session's entry in ``session_last_activity``, the
    timestamp ``cleanup_sessions`` compares against ``SESSION_TIMEOUT``.
    """
    state = session_states.get(session_id)
    if state is None:
        state = session_states[session_id] = ConversationState()
    # Refresh on every lookup, not only on creation, so a conversation that is
    # still active is never expired mid-dialogue.
    session_last_activity[session_id] = time.time()
    return state
def extract_info(message):
    """Extract all relevant information from a message at once.

    Applies keyword and regex heuristics to the lower-cased ``message``. Only
    fields that were actually found appear in the returned dict, so callers can
    merge it into a profile without clobbering earlier answers.

    Possible keys: ``dietary_preference``, ``health_goal``, ``age`` (digits as
    a string), ``weight`` / ``height`` (the matched number-plus-unit text),
    ``gender``, ``activity_level``, and ``restrictions`` / ``allergies``
    (lists of the matched keywords).
    """
    text = message.lower()
    info = {}

    preferences = ["vegetarian", "vegan", "keto", "paleo", "gluten-free", "low-carb",
                   "pescatarian", "mediterranean", "dash", "plant-based", "flexitarian",
                   "low-fat", "carnivore", "whole food"]
    preference = next((pref for pref in preferences if pref in text), None)
    if preference:
        info["dietary_preference"] = preference

    # First matching key wins, in insertion order.
    goals = {
        "lose weight": "weight loss",
        "weight loss": "weight loss",
        "gain weight": "weight gain",
        "weight gain": "weight gain",
        "maintain": "weight maintenance",
        # The bot itself suggests "maintenance", which does not contain "maintain".
        "maintenance": "weight maintenance",
        "muscle": "muscle building",
        "energy": "improved energy",
        "health": "better health",
    }
    goal = next((label for key, label in goals.items() if key in text), None)
    if goal:
        info["health_goal"] = goal

    age_match = re.search(r'(\d+)\s*(?:years|year|yr|y)(?:\s*old)?', text)
    if age_match:
        info["age"] = age_match.group(1)
    weight_match = re.search(r'(\d+(?:\.\d+)?)\s*(?:kg|kilos?|pounds?|lbs?)', text)
    if weight_match:
        info["weight"] = weight_match.group(0)
    # (?![a-z]) keeps short units such as "m" and "in" from matching the start of
    # words like "male", "moderate" or "inactive" ("30 male" is not a height).
    height_match = re.search(
        r'(\d+(?:\.\d+)?)\s*(?:cm|centimeters?|meters?|m|feet|foot|ft|\'|inches|in|")(?![a-z])',
        text,
    )
    if height_match:
        info["height"] = height_match.group(0)

    # "male" is a substring of "female", so a plain substring test can never
    # report female; match whole words and test "female" first.
    if re.search(r"\bfemale\b", text):
        info["gender"] = "female"
    elif re.search(r"\bmale\b", text):
        info["gender"] = "male"

    # Ordered from least to most active; "inactive" must hit "sedentary" before
    # the bare "active" keyword of "very active" is considered.
    activity_keywords = {
        "sedentary": ["sedentary", "inactive", "not active", "desk job"],
        "lightly active": ["lightly active", "light activity", "light exercise", "walk"],
        "moderately active": ["moderately active", "moderate activity", "moderate exercise"],
        "very active": ["very active", "active", "highly active", "exercise regularly"],
    }
    level = next(
        (lvl for lvl, keywords in activity_keywords.items() if any(k in text for k in keywords)),
        None,
    )
    if level:
        info["activity_level"] = level

    # "low sodium" / "low sugar" are the examples the bot offers when asking.
    restrictions = ["lactose intolerance", "high sodium", "added sugar", "low fodmap",
                    "low carb", "celiac", "gluten", "high cholesterol", "ibs",
                    "low sodium", "low sugar"]
    found_restrictions = [r for r in restrictions if r in text]
    if found_restrictions:
        info["restrictions"] = found_restrictions

    allergies = ["nuts", "shellfish", "eggs", "dairy", "wheat", "soy", "fish",
                 "peanuts"]
    found_allergies = [a for a in allergies if a in text]
    if found_allergies:
        info["allergies"] = found_allergies

    return info
# Whole-word replies meaning "nothing to record" at the restrictions/allergies steps.
_NEGATIVE_ANSWERS = ("no", "none", "nope", "nothing", "not", "nah")
# Affirmations accepted at the confirmation step (substring match).
_CONFIRM_WORDS = ("yes", "correct", "right", "good", "ok", "okay", "fine", "sure")
# Whole words that turn a would-be confirmation into a request to edit the profile.
_NEGATION_WORDS = ("no", "not", "incorrect", "wrong")


def _contains_word(text, words):
    """Return True if any entry of ``words`` occurs in ``text`` as a whole word or phrase.

    Plain substring tests misfire on short keywords ("no" inside "know" or
    "cannot", "correct" inside "incorrect"), so matches are anchored on word
    boundaries.
    """
    return any(re.search(rf"\b{re.escape(word)}\b", text) for word in words)


def _profile_summary(profile):
    """Return the field-by-field listing of ``profile`` shown to the user, ending in a blank line."""
    restrictions = ", ".join(profile["restrictions"]) if profile["restrictions"] else "None"
    allergies = ", ".join(profile["allergies"]) if profile["allergies"] else "None"
    return (
        f"Dietary preference: {profile['dietary_preference']}\n"
        f"Health goal: {profile['health_goal']}\n"
        f"Age: {profile['age']}\n"
        f"Weight: {profile['weight']}\n"
        f"Height: {profile['height']}\n"
        f"Gender: {profile['gender']}\n"
        f"Activity level: {profile['activity_level']}\n"
        f"Restrictions: {restrictions}\n"
        f"Allergies: {allergies}\n\n"
    )


def _build_plan_prompt(profile):
    """Return the Cohere prompt asking for a weekly meal plan tailored to ``profile``.

    The prompt names the exact header / label spellings that format_markdown
    rewrites, so the generated plan actually picks up the custom styling.
    """
    restrictions = ", ".join(profile["restrictions"]) if profile["restrictions"] else "None"
    allergies = ", ".join(profile["allergies"]) if profile["allergies"] else "None"
    return f"""
Create a personalized weekly meal plan based on the following user information:
Dietary Preference: {profile['dietary_preference']}
Health Goal: {profile['health_goal']}
Age: {profile['age']}
Weight: {profile['weight']}
Height: {profile['height']}
Gender: {profile['gender']}
Activity Level: {profile['activity_level']}
Dietary Restrictions: {restrictions}
Allergies: {allergies}
Generate a complete meal plan with:
1. Diet Overview - Brief summary of the diet approach
2. Weekly Meal Plan - Day-by-day plan with breakfast, lunch, dinner, and snacks (sometimes).
3. General Recommendations - Additional advice and nutrition tips
Format in Markdown with clear headers and bullet points.
Use "## Diet Overview", "## Weekly Meal Plan" and "## General Recommendations" as section headers, "### Monday" through "### Sunday" as day headers, and "**Breakfast**:", "**Lunch**:", "**Dinner**:", "**Snacks**:" as meal labels.
"""


def _generate(prompt, max_tokens):
    """Run ``prompt`` through the shared Cohere client and return the first generation's text."""
    return co.generate(
        prompt=prompt,
        model="command",
        max_tokens=max_tokens,
        temperature=0.7,
    ).generations[0].text


def chat_with_diet_bot(message, history, session_id=None):
    """Advance one user's diet-planning conversation by a single message.

    Wired to ``msg.submit`` with outputs ``[msg, chatbot, session_id]``, so every
    return path yields three values.

    Args:
        message: the user's text.
        history: list of ``[user_message, bot_reply]`` pairs shown in the chatbot.
        session_id: key into the per-session state store; a new UUID is minted when None.

    Returns:
        ``("", new_history, session_id)``. The empty string clears the textbox.
    """
    if session_id is None:
        session_id = str(uuid.uuid4())
    if not message.strip():
        # Three outputs are bound to this handler, so even the no-op path must
        # return three values.
        return "", history, session_id

    state = get_session_state(session_id)
    text = message.lower()

    if any(word in text for word in ["reset", "start over"]):
        state.reset()
        state.current_step = "request_diet_plan"
        response = "I've reset your information. Would you like to create a diet plan?"
        return "", history + [[message, response]], session_id
    if any(word in text for word in ["goodbye", "bye", "exit"]):
        state.reset()
        response = "Goodbye! Feel free to come back anytime for diet advice."
        return "", history + [[message, response]], session_id

    # Whatever this message reveals is merged into the profile at every step.
    info = extract_info(message)
    for key, value in info.items():
        # List fields are taken as given; scalar fields only overwrite with a real value.
        if value or key in ("restrictions", "allergies"):
            state.user_profile[key] = value

    profile = state.user_profile
    step = state.current_step
    response = ""

    if step == "greeting":
        if any(word in text for word in ["hello", "hi", "hey", "diet", "plan", "meal"]):
            response = "Hello! I'm your diet planning assistant. Would you like me to create a personalized diet plan for you?"
        else:
            response = "Hello! I can help create a personalized meal plan based on your preferences and goals. Would you like me to help with that?"
        state.current_step = "request_diet_plan"
    elif step == "request_diet_plan":
        if any(word in text for word in ["yes", "sure", "okay", "please", "diet", "plan"]):
            response = "Great! What type of diet are you following? (e.g., vegetarian, vegan, keto, paleo, gluten-free, or no specific diet)"
            state.current_step = "ask_dietary_preference"
        else:
            response = "I'm here to help with diet planning whenever you're ready."
            state.current_step = "greeting"
    elif step == "ask_dietary_preference":
        if profile["dietary_preference"]:
            response = f"Thanks for letting me know you follow a {profile['dietary_preference']} diet. What's your health goal? (e.g., weight loss, weight gain, maintenance, muscle building)"
        else:
            profile["dietary_preference"] = "standard balanced"
            response = "I'll focus on a standard balanced diet plan. What's your health goal? (e.g., weight loss, weight gain, maintenance, muscle building)"
        state.current_step = "ask_health_goal"
    elif step == "ask_health_goal":
        if profile["health_goal"]:
            response = "Could you share your age, weight, height, gender, and activity level? For example: '35 years old, 70kg, 175cm, male, moderate activity'"
        else:
            profile["health_goal"] = "general health"
            response = "I'll focus on general health improvements. Could you share your age, weight, height, gender, and activity level?"
        state.current_step = "ask_personal_info"
    elif step == "ask_personal_info":
        missing = [item for item in ("age", "weight", "height", "gender", "activity_level") if not profile[item]]
        if missing:
            response = f"Thanks for that information. I still need your {', '.join(missing)}. Could you provide that?"
        else:
            response = "Do you have any dietary restrictions I should be aware of? (e.g., lactose intolerance, low sodium, low sugar)"
            state.current_step = "ask_restrictions"
    elif step == "ask_restrictions":
        # Clear only when this reply named no restriction and says "no"/"none" as
        # a whole word; a substring test wiped data on "know"/"cannot" and on
        # replies like "lactose intolerance, nothing else".
        if "restrictions" not in info and _contains_word(text, _NEGATIVE_ANSWERS):
            profile["restrictions"] = []
        response = "Do you have any food allergies? (e.g., nuts, shellfish, eggs, dairy)"
        state.current_step = "ask_allergies"
    elif step == "ask_allergies":
        if "allergies" not in info and _contains_word(text, _NEGATIVE_ANSWERS):
            profile["allergies"] = []
        response = (
            "I've collected the following information:\n"
            + _profile_summary(profile)
            + "Is this information correct? I'll use it to create your personalized diet plan."
        )
        state.current_step = "confirm_information"
    elif step == "confirm_information":
        affirmed = any(word in text for word in _CONFIRM_WORDS)
        # Without this, "incorrect", "not right" or "doesn't look right" would
        # all confirm and trigger plan generation from wrong data.
        negated = _contains_word(text, _NEGATION_WORDS) or "n't" in text or "n\u2019t" in text
        if affirmed and not negated:
            try:
                plan = format_markdown(_generate(_build_plan_prompt(profile), max_tokens=2000))
            except Exception:
                # External API boundary: keep the chat alive but record why it failed.
                logging.getLogger(__name__).exception("Diet plan generation failed")
                response = "I'm sorry, I couldn't generate a diet plan at the moment. Please try again later."
                state.current_step = "greeting"
            else:
                state.given_diet = plan
                response = plan + "\n\nI hope this meal plan helps you achieve your goals! Feel free to ask if you have any nutrition questions."
                state.current_step = "diet_questions"
        else:
            response = "Let's update your information. What would you like to change?"
            state.current_step = "update_information"
    elif step == "update_information":
        # The corrections themselves were already merged by extract_info above.
        response = (
            "I've updated your information:\n"
            + _profile_summary(profile)
            + "Is this information correct now?"
        )
        state.current_step = "confirm_information"
    elif step == "diet_questions":
        prompt = f"""
Answer this nutrition question clearly and accurately:
Question: {message}
Previous given diet: {state.given_diet}
Provide helpful, evidence-based advice with practical tips.
"""
        try:
            response = _generate(prompt, max_tokens=800)
        except Exception:
            logging.getLogger(__name__).exception("Nutrition answer generation failed")
            response = "I'm sorry, I couldn't answer your question at the moment. Please try again later."

    # Covers an empty model reply or an unrecognised step.
    if not response:
        response = "I'm sorry, I didn't understand. Could you please rephrase your question?"
    return "", history + [[message, response]], session_id
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    with gr.Column():
        # Session key into session_states; demo.load below also replaces it with
        # a fresh UUID on every page load.
        session_id = gr.State(value=lambda: str(uuid.uuid4()))
        gr.HTML("""
<div class="app-header">
<h1 class="app-title">Diet Planning Chatbot</h1>
</div>
""")
        gr.Markdown(
            """
Hello! I can help you create a weekly meal plan based on your dietary preferences, health goals, and personal information.
### How to use:
1. Start by saying hello or asking for a diet plan
2. The chatbot will ask about your dietary needs, health goals, and other information
3. After confirming, it will generate a personalized weekly meal plan
"""
        )
        chatbot = gr.Chatbot(
            [],
            elem_id="chatbot",
            avatar_images=[
                "https://api.dicebear.com/7.x/thumbs/svg?seed=user",
                "https://api.dicebear.com/7.x/thumbs/svg?seed=bot&backgroundColor=2e7d32",
            ],
            height=500,
            render_markdown=True,
        )
        with gr.Row(elem_classes="input-container"):
            msg = gr.Textbox(
                show_label=False,
                placeholder="Type your message here...",
                scale=9,
            )
            clear = gr.Button("Clear", scale=1)

        # chat_with_diet_bot returns (textbox value, history, session id) for these outputs.
        msg.submit(chat_with_diet_bot, [msg, chatbot, session_id], [msg, chatbot, session_id])

        def clear_chat():
            """Return an empty chat history and a brand-new session id."""
            return [], str(uuid.uuid4())

        # Page load needs exactly the same reset as the Clear button; the old
        # name is kept for anything that refers to it.
        initial_greeting = clear_chat

        clear.click(clear_chat, None, [chatbot, session_id])
        demo.load(initial_greeting, None, [chatbot, session_id])
        gr.HTML("""
<div class="features-container">
<div class="feature-card">
<h3>Weekly Meal Plans</h3>
<p>Complete 7-day meal plans with breakfast, lunch, dinner, and snacks.</p>
</div>
<div class="feature-card">
<h3>Personalized</h3>
<p>Tailored to your specific dietary needs and health goals.</p>
</div>
<div class="feature-card">
<h3>Dietary Restrictions</h3>
<p>Support for vegetarian, vegan, keto, gluten-free and more.</p>
</div>
</div>
""")

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside the host/container.
    demo.launch(server_name="0.0.0.0")