ak0601 committed on
Commit
a495cca
·
verified ·
1 Parent(s): 7243569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -41
app.py CHANGED
@@ -90,42 +90,51 @@
90
  # chat_history.clear()
91
  # return RedirectResponse("/", status_code=303)
92
 
93
-
94
  import os
95
  import io
96
  import streamlit as st
97
  from dotenv import load_dotenv
98
  from PIL import Image
99
  import google.generativeai as genai
100
- from langgraph.graph import StateGraph,END
101
  from typing import TypedDict, List, Union
102
 
 
 
 
103
  load_dotenv()
104
  API_KEY = os.getenv("GOOGLE_API_KEY")
105
- genai.configure(api_key = API_KEY)
106
 
107
  model = genai.GenerativeModel("gemini-2.0-flash")
108
 
 
 
 
109
  class ChatState(TypedDict):
110
  user_input: str
111
- image: Union[Image.Image,None]
112
- raw_response:str
113
- final_response:str
114
- chat_history:List[dict]
115
 
116
 
117
- def input_node(state: ChatState)->ChatState:
 
 
 
118
  return state
119
 
120
- def processing_node(state:ChatState) -> ChatState:
 
121
  parts = [state["user_input"]]
122
  if state["image"]:
123
  parts.append(state["image"])
124
 
125
  try:
126
- chat = model.start_chat(history = [])
127
  resp = chat.send_message(parts)
128
-
129
  except Exception as e:
130
  state["raw_response"] = f"Error: {e}"
131
 
@@ -135,54 +144,76 @@ def processing_node(state:ChatState) -> ChatState:
135
  def checking_node(state: ChatState) -> ChatState:
136
  raw = state["raw_response"]
137
 
138
- if "Sure!" in raw or "The image shows" in raw or raw.startswith("I can see"):
 
139
  lines = raw.split("\n")
140
- filtered_lines = [line for line in lines if not line.startswith("Sure!") and "The image shows" not in line]
141
- state["final_response"] = "\n".join(filtered_lines).strip()
 
 
 
 
142
  else:
143
  state["final_response"] = raw
144
 
145
- st.session_state.chat_history.append({"role":"user","content":state["user_input"]})
146
- st.session_state.chat_history.append({"role":"model","content":state["final_response"]})
 
147
 
148
  return state
149
 
 
 
 
 
150
  builder = StateGraph(ChatState)
151
- builder.add_node("input",input_node)
152
- builder.add_node("processing",processing_node)
153
- builder.add_node("checking",checking_node)
154
 
155
  builder.set_entry_point("input")
156
- builder.add_edge("input","processing")
157
- builder.add_edge("processing","checking")
158
- builder.add_edge("checking",END)
159
-
160
 
161
  graph = builder.compile()
162
 
163
- st.set_page_config(page_title="Math Chatbot",layout="centered")
 
 
 
164
  st.title("Math Chatbot")
165
 
 
166
  if "chat_history" not in st.session_state:
167
  st.session_state.chat_history = []
168
-
 
169
  for msg in st.session_state.chat_history:
170
  with st.chat_message(msg["role"]):
171
  st.markdown(msg["content"])
172
 
 
 
 
173
  with st.sidebar:
174
  st.header("Options")
175
  if st.button("New Chat"):
176
- st.session_state.chat_historyb = []
177
  st.rerun()
178
 
179
-
180
- with st.form("chat_form",clear_on_submit=True):
 
 
181
  user_input = st.text_input("Your message:", placeholder="Ask your math problem here")
182
- uploaded_file = st.file_uploader("Upload an image",type = ["jpg","png","jpeg"])
 
 
183
  submitted = st.form_submit_button("Send")
184
 
185
  if submitted:
 
186
  image = None
187
  if uploaded_file:
188
  try:
@@ -191,15 +222,18 @@ with st.form("chat_form",clear_on_submit=True):
191
  st.error(f"Error loading image: {e}")
192
  st.stop()
193
 
194
- input_state = {
195
- "user_input":user_input,
196
- "image": image,
197
- "raw_response":"",
198
- "final_response":"",
199
- "chat_history":st.session_state.chat_history
200
- }
201
-
202
- output = graph.invoke(input_state)
203
-
204
- with st.chat_message("model"):
205
- st.markdown(output["final_response"])
 
 
 
 
90
  # chat_history.clear()
91
  # return RedirectResponse("/", status_code=303)
92
 
 
93
  import os
94
  import io
95
  import streamlit as st
96
  from dotenv import load_dotenv
97
  from PIL import Image
98
  import google.generativeai as genai
99
+ from langgraph.graph import StateGraph, END
100
  from typing import TypedDict, List, Union
101
 
102
# ---------------------------
# Load API Key
# ---------------------------
# Read GOOGLE_API_KEY from the environment (a local .env file is honoured
# via python-dotenv) and configure the Gemini client with it.
load_dotenv()
API_KEY = os.getenv("GOOGLE_API_KEY")
# Fail fast with a clear message instead of letting the genai client raise
# a cryptic authentication error on the first request.
if not API_KEY:
    raise RuntimeError(
        "GOOGLE_API_KEY is not set. Add it to your environment or a .env file."
    )
genai.configure(api_key=API_KEY)

# Single shared model handle used by every request in this app.
model = genai.GenerativeModel("gemini-2.0-flash")
110
 
111
# ---------------------------
# State Definition
# ---------------------------
class ChatState(TypedDict):
    """State carried between LangGraph nodes for a single chat turn."""

    # Raw text the user typed into the chat form.
    user_input: str
    # Optional PIL image attached to the message (None when nothing uploaded).
    image: Union[Image.Image, None]
    # Unfiltered text returned by the Gemini model (or an error string).
    raw_response: str
    # Cleaned-up response shown to the user and stored in history.
    final_response: str
    # Running conversation; mirrors st.session_state.chat_history.
    chat_history: List[dict]
120
 
121
 
122
# ---------------------------
# LangGraph Nodes
# ---------------------------
def input_node(state: ChatState) -> ChatState:
    """Entry node: passes the initial state through unchanged.

    Exists only so the graph has an explicit entry point; no transformation
    happens here.
    """
    return state
127
 
128
+
129
# Node 2: send the user's text (plus any attached image) to Gemini and
# stash the raw reply — or an error string — on the state.
def processing_node(state: ChatState) -> ChatState:
    """Call Gemini for the current turn and record the raw response."""
    message_parts = [state["user_input"]]
    attached_image = state["image"]
    if attached_image:
        message_parts.append(attached_image)

    try:
        # A fresh chat session per turn; earlier turns are not replayed here.
        session = model.start_chat(history=[])
        resp = session.send_message(message_parts)
        state["raw_response"] = resp.text
    except Exception as e:
        # Surface API failures to the user instead of crashing the app.
        state["raw_response"] = f"Error: {e}"
 
 
144
def checking_node(state: ChatState) -> ChatState:
    """Strip boilerplate openers from the model reply and log the turn."""
    raw = state["raw_response"]

    # Drop filler lines Gemini tends to prepend ("Sure!", image narration).
    if raw.startswith("Sure!") or "The image shows" in raw:
        kept = []
        for line in raw.split("\n"):
            if line.startswith("Sure!"):
                continue
            if "The image shows" in line:
                continue
            kept.append(line)
        state["final_response"] = "\n".join(kept).strip()
    else:
        state["final_response"] = raw

    # Record both sides of the exchange in the Streamlit session history.
    st.session_state.chat_history.append({"role": "user", "content": state["user_input"]})
    st.session_state.chat_history.append({"role": "model", "content": state["final_response"]})

    return state
164
 
165
+
166
# ---------------------------
# Build the LangGraph
# ---------------------------
# Linear pipeline: input -> processing -> checking -> END.
builder = StateGraph(ChatState)
for node_name, node_fn in (
    ("input", input_node),
    ("processing", processing_node),
    ("checking", checking_node),
):
    builder.add_node(node_name, node_fn)

builder.set_entry_point("input")
for source, target in (
    ("input", "processing"),
    ("processing", "checking"),
    ("checking", END),
):
    builder.add_edge(source, target)

graph = builder.compile()
180
 
181
# ---------------------------
# Streamlit UI Setup
# ---------------------------
st.set_page_config(page_title="Math Chatbot", layout="centered")
st.title("Math Chatbot")

# Create the conversation store on first page load.
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = []

# Replay the conversation so far, one chat bubble per stored message.
for message in st.session_state["chat_history"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
195
 
196
# ---------------------------
# Sidebar
# ---------------------------
with st.sidebar:
    st.header("Options")
    # "New Chat" wipes the stored conversation and redraws the page.
    if st.button("New Chat"):
        st.session_state["chat_history"] = []
        st.rerun()
204
 
205
# ---------------------------
# Chat Input Form
# ---------------------------
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("Your message:", placeholder="Ask your math problem here")

    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

    submitted = st.form_submit_button("Send")

    if submitted:
        # Ignore empty submissions instead of sending a blank prompt to Gemini.
        if not user_input.strip() and uploaded_file is None:
            st.warning("Please enter a message or attach an image.")
            st.stop()

        # Load image safely
        image = None
        if uploaded_file:
            try:
                # NOTE(review): this load step was not visible in the diff and
                # is reconstructed — confirm against the original source.
                image = Image.open(uploaded_file)
            except Exception as e:
                st.error(f"Error loading image: {e}")
                st.stop()

        # Prepare state for a single graph run.
        input_state = {
            "user_input": user_input,
            "image": image,
            "raw_response": "",
            "final_response": "",
            "chat_history": st.session_state.chat_history,
        }

        # Run LangGraph
        output = graph.invoke(input_state)

        # Show model response
        with st.chat_message("model"):
            st.markdown(output["final_response"])