# Doubt-Solver / app.py
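# Legacy FastAPI implementation, kept commented out for reference; the active
# app below reimplements the same chat flow with Streamlit and LangGraph.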
# from fastapi import FastAPI, Request, Form, UploadFile, File
# from fastapi.templating import Jinja2Templates
# from fastapi.responses import HTMLResponse, RedirectResponse
# from fastapi.staticfiles import StaticFiles
# from dotenv import load_dotenv
# import os, io
# from PIL import Image
# import markdown
# import google.generativeai as genai
#
# # Load environment variables
# load_dotenv()
# API_KEY = os.getenv("GOOGLE_API_KEY")
# genai.configure(api_key=API_KEY)
#
# app = FastAPI()
# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
# model = genai.GenerativeModel('gemini-2.0-flash')
#
# # Global chat session and display history
# chat = None
# chat_history = []
#
# @app.get("/", response_class=HTMLResponse)
# async def root(request: Request):
#     return templates.TemplateResponse("index.html", {
#         "request": request,
#         "chat_history": chat_history,
#     })
#
# @app.post("/", response_class=HTMLResponse)
# async def handle_input(
#     request: Request,
#     user_input: str = Form(...),
#     image: UploadFile = File(None),
# ):
#     global chat, chat_history
#     # Initialize the chat session if needed
#     if chat is None:
#         chat = model.start_chat(history=[])
#     parts = []
#     if user_input:
#         parts.append(user_input)
#     # For display in the UI
#     user_message = user_input
#     if image and image.content_type.startswith("image/"):
#         data = await image.read()
#         try:
#             img = Image.open(io.BytesIO(data))
#             parts.append(img)
#             user_message += " [Image uploaded]"  # Indicate the image in chat history
#         except Exception as e:
#             chat_history.append({
#                 "role": "model",
#                 "content": markdown.markdown(f"**Error loading image:** {e}")
#             })
#             return RedirectResponse("/", status_code=303)
#     # Store the user message for display
#     chat_history.append({"role": "user", "content": user_message})
#     try:
#         # Send the message to the Gemini model
#         resp = chat.send_message(parts)
#         # Add the model response to the history
#         raw = resp.text
#         chat_history.append({"role": "model", "content": raw})
#     except Exception as e:
#         err = f"**Error:** {e}"
#         chat_history.append({
#             "role": "model",
#             "content": markdown.markdown(err)
#         })
#     # Post-Redirect-Get
#     return RedirectResponse("/", status_code=303)
#
# # Clear chat history and start fresh
# @app.post("/new")
# async def new_chat():
#     global chat, chat_history
#     chat = None
#     chat_history.clear()
#     return RedirectResponse("/", status_code=303)
import os
import io
import streamlit as st
from dotenv import load_dotenv
from PIL import Image
import google.generativeai as genai
from langgraph.graph import StateGraph, END
from typing import TypedDict, List, Union
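# Assumed dependencies, based on the imports above: streamlit, python-dotenv,
# pillow, google-generativeai, langgraph.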
# Load the Gemini API key from .env / the environment
load_dotenv()
API_KEY = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=API_KEY)
model = genai.GenerativeModel("gemini-2.0-flash")
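# Gemini chat sessions accept multimodal input: send_message() can take a
# list mixing strings and PIL.Image.Image objects, which is how the uploaded
# image reaches the model below.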
# Shared state that flows through the LangGraph pipeline
class ChatState(TypedDict):
    user_input: str
    image: Union[Image.Image, None]
    raw_response: str
    final_response: str
    chat_history: List[dict]

def input_node(state: ChatState) -> ChatState:
    # Entry node: pass the state through unchanged
    return state
def processing_node(state: ChatState) -> ChatState:
    # Assemble the multimodal prompt: text first, optional PIL image second
    parts = [state["user_input"]]
    if state["image"]:
        parts.append(state["image"])
    try:
        chat = model.start_chat(history=[])
        resp = chat.send_message(parts)
        state["raw_response"] = resp.text
    except Exception as e:
        # On failure, surface the error text through the normal response path
        state["raw_response"] = f"Error: {e}"
    return state
def checking_node(state: ChatState) -> ChatState:
    # Heuristic cleanup: drop boilerplate openers the model tends to prepend
    raw = state["raw_response"]
    if "Sure!" in raw or "The image shows" in raw or raw.startswith("I can see"):
        lines = raw.split("\n")
        filtered_lines = [
            line for line in lines
            if not line.startswith("Sure!") and "The image shows" not in line
        ]
        state["final_response"] = "\n".join(filtered_lines).strip()
    else:
        state["final_response"] = raw
    # Persist both sides of the exchange so the UI can replay them on rerun
    st.session_state.chat_history.append({"role": "user", "content": state["user_input"]})
    st.session_state.chat_history.append({"role": "model", "content": state["final_response"]})
    return state
builder = StateGraph(ChatState)
builder.add_node("input", input_node)
builder.add_node("processing", processing_node)
builder.add_node("checking", checking_node)
builder.set_entry_point("input")
builder.add_edge("input", "processing")
builder.add_edge("processing", "checking")
builder.add_edge("checking", END)
graph = builder.compile()
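# Graph topology: input -> processing -> checking -> END.
# A minimal sketch of invoking the compiled graph directly (hypothetical
# usage, not part of the app; note that checking_node writes to
# st.session_state, so this only works during a Streamlit script run):
#
#   result = graph.invoke({
#       "user_input": "Solve 3x + 5 = 20",
#       "image": None,
#       "raw_response": "",
#       "final_response": "",
#       "chat_history": [],
#   })
#   print(result["final_response"])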
st.set_page_config(page_title="Math Chatbot", layout="centered")
st.title("Math Chatbot")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Replay the stored conversation on each rerun
for msg in st.session_state.chat_history:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])

with st.sidebar:
    st.header("Options")
    if st.button("New Chat"):
        st.session_state.chat_history = []
        st.rerun()
with st.form("chat_form", clear_on_submit=True):
    user_input = st.text_input("Your message:", placeholder="Ask your math problem here")
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
    submitted = st.form_submit_button("Send")

if submitted:
    image = None
    if uploaded_file:
        try:
            image = Image.open(io.BytesIO(uploaded_file.read()))
        except Exception as e:
            st.error(f"Error loading image: {e}")
            st.stop()
    input_state = {
        "user_input": user_input,
        "image": image,
        "raw_response": "",
        "final_response": "",
        "chat_history": st.session_state.chat_history,
    }
    # Run the pipeline; checking_node appends both messages to session state,
    # so the full exchange renders on the next rerun
    output = graph.invoke(input_state)
    with st.chat_message("model"):
        st.markdown(output["final_response"])
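# To run locally (assuming this file is saved as app.py and GOOGLE_API_KEY
# is set in .env or the environment):
#
#   streamlit run app.py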